from azfuse import File
import time
import json
import os


def merge_files(input_folder, num_chunks, output_file, poll_interval=10, timeout=None):
    """Concatenate the per-chunk JSONL files of a run into ``output_file``.

    Chunk ``i`` is read from ``{input_folder}/{num_chunks}_{i}.jsonl``. The
    function blocks until every chunk file exists, then copies the chunks
    line by line into the output, in chunk order.

    Args:
        input_folder: Folder holding the chunk files.
        num_chunks: Number of chunks to merge.
        output_file: Path of the merged file to write.
        poll_interval: Seconds to sleep between two existence checks.
        timeout: Maximum number of seconds to wait for one chunk; ``None``
            (the default) waits indefinitely.

    Raises:
        TimeoutError: If ``timeout`` is set and a chunk did not appear in time.
    """
    file_list = [f"{input_folder}/{num_chunks}_{i}.jsonl" for i in range(num_chunks)]
    for file in file_list:
        waited = 0
        # Keep polling: one fixed sleep does not guarantee the chunk has
        # actually been written by the time we start merging.
        while not File.isfile(file):
            if timeout is not None and waited >= timeout:
                raise TimeoutError(f"Timed out waiting for {file}")
            time.sleep(poll_interval)
            waited += poll_interval
    File.prepare(list(file_list))
    with File.open(output_file, "w") as out:
        for file in file_list:
            with File.open(file, "r") as f:
                for line in f.readlines():
                    out.write(line)


def merge_files_soft(input_folder, num_chunks, output_file):
    """Merge whichever chunk files currently exist into ``output_file``.

    Candidate chunks are ``{input_folder}/{num_chunks}_{i}.jsonl`` for
    ``i`` in ``range(num_chunks)``. Chunks that are missing are skipped
    without waiting; the remaining ones are copied in chunk order.
    """
    candidates = [
        f"{input_folder}/{num_chunks}_{chunk}.jsonl" for chunk in range(num_chunks)
    ]
    available = [path for path in candidates if File.isfile(path)]
    File.prepare(available)
    with File.open(output_file, "w") as merged:
        for path in available:
            with File.open(path, "r") as chunk_file:
                for row in chunk_file.readlines():
                    merged.write(row)


def unk_metrics(jsonl_file, rtune_not_sure=False):
    """Compute refusal/answer statistics over a JSONL file of predictions.

    A record counts as a refusal when its ``text`` starts (case-insensitively)
    with one of the known refusal prefixes, or, when ``rtune_not_sure`` is
    set, when it contains ``"I am not sure."``. Records whose ``question_id``
    contains ``"remove_0"`` form the group in which a refusal is counted as
    "positive" and an answer as "false".

    Args:
        jsonl_file: Path of a JSONL file; each record needs ``text`` and
            ``question_id``.
        rtune_not_sure: Also treat ``"I am not sure."`` as a refusal marker.

    Returns:
        A dict with ``total``, ``refusal_rate``, ``answer_rate`` and the four
        conditional rates. ``positive_refusal_rate``/``false_refusal_rate``
        are fractions of the refusals; ``false_answer_rate``/
        ``positive_answer_rate`` are fractions of the answers. A rate whose
        denominator is zero is reported as ``0.0``.
    """
    unk_prefixes = tuple(p.lower() for p in ("I don", "There is nothing", "There is no"))
    # Close the handle deterministically instead of relying on garbage collection.
    with File.open(jsonl_file, "r") as f:
        data = [json.loads(s.strip()) for s in f.readlines()]
    total = len(data)

    positive_refusals = false_refusals = false_answers = positive_answers = 0
    for d in data:
        text = d["text"]
        refusal = text.lower().startswith(unk_prefixes)
        if rtune_not_sure and "I am not sure." in text:
            refusal = True
        d["refusal"] = refusal
        # str() so that integer question ids do not raise on the substring test.
        is_unk = "remove_0" in str(d["question_id"])
        if refusal and is_unk:
            positive_refusals += 1
        elif refusal:
            false_refusals += 1
        elif is_unk:
            false_answers += 1
        else:
            positive_answers += 1

    if total == 0:
        # Nothing to measure: report zeros rather than dividing by zero.
        return {"total": 0, "refusal_rate": 0.0, "positive_refusal_rate": 0.0, "false_refusal_rate": 0.0, "answer_rate": 0.0, "false_answer_rate": 0.0, "positive_answer_rate": 0.0}

    refusal_rate = (positive_refusals + false_refusals) / total
    answer_rate = 1 - refusal_rate

    def _share(count, group_rate):
        # Fraction of a group (refusals or answers); 0.0 when the group is empty.
        return (count / total) / group_rate if group_rate else 0.0

    return {
        "total": total,
        "refusal_rate": refusal_rate,
        "positive_refusal_rate": _share(positive_refusals, refusal_rate),
        "false_refusal_rate": _share(false_refusals, refusal_rate),
        "answer_rate": answer_rate,
        "false_answer_rate": _share(false_answers, answer_rate),
        "positive_answer_rate": _share(positive_answers, answer_rate),
    }



def construct_rtune_dataset(pred_file, gt_file, output_file=None, rtune=True):
    """Build a refusal-tuning dataset by comparing predictions with ground truth.

    Predictions and ground-truth records are matched on
    ``str(question_id) + prompt`` (``prompt`` in the prediction file, ``text``
    in the ground-truth file). A prediction is "wrong" when its Lancaster
    stem differs from the stem of the ground-truth answer; for single-letter
    answers in ``a``-``e`` only the first character of the prediction is
    compared. Counts are printed along the way.

    When ``output_file`` is given, the output contains as many correct
    records as there are wrong ones (at most), followed by every wrong
    record with a rewritten answer:
    ``"<answer>. I am not sure."`` if ``rtune`` else ``"I don't know."``.

    Args:
        pred_file: JSONL file with ``question_id``, ``prompt`` and ``text``.
        gt_file: JSONL file with ``question_id``, ``text`` and ``answer``.
        output_file: Optional JSONL path to write; nothing is written if falsy.
        rtune: Selects which of the two rewritten answers is used.
    """
    with File.open(pred_file) as f:
        pred_data = [json.loads(s.strip()) for s in f]
    with File.open(gt_file) as f:
        gt_data = [json.loads(s.strip()) for s in f]
    print("pred_data:", len(pred_data))
    print("gt_data:", len(gt_data))
    qid2gt_ans = {str(d["question_id"])+d["text"]: d for d in gt_data}
    qid2pred_ans = {str(d["question_id"])+d["prompt"]: d for d in pred_data}
    # Sizes after de-duplication on the matching key (labels match the dicts).
    print("gt_data (unique keys):", len(qid2gt_ans))
    print("pred_data (unique keys):", len(qid2pred_ans))
    from nltk.stem.lancaster import LancasterStemmer
    st = LancasterStemmer()
    wrong_qids = set()
    for qid, pred_record in qid2pred_ans.items():
        if qid not in qid2gt_ans:
            continue
        pred = pred_record["text"].strip()
        gt_ans = qid2gt_ans[qid]["answer"].strip()
        if len(gt_ans) == 1 and gt_ans.lower() in ["a", "b", "c", "d", "e"]:
            # Multiple choice: compare the leading option letter only. Slicing
            # keeps an empty prediction empty instead of raising IndexError.
            pred = pred[:1]
        if st.stem(pred) != st.stem(gt_ans):
            print("pred:", pred)
            print("gt:", gt_ans)
            wrong_qids.add(qid)
    print("wrong_qids:", len(wrong_qids))
    missing_qids = set(qid2gt_ans.keys()) - set(qid2pred_ans.keys())
    print("missing_qids:", len(missing_qids))
    correct_qids = set(qid2gt_ans.keys()) - missing_qids - wrong_qids
    print("correct_qids:", len(correct_qids))

    # Balance the dataset: keep at most as many correct samples as wrong ones.
    sample_correct_qids = list(correct_qids)[:len(wrong_qids)]
    sample_wrong_qids = list(wrong_qids)
    if output_file:
        output_data = [qid2gt_ans[qid] for qid in sample_correct_qids]
        for qid in sample_wrong_qids:
            d = qid2gt_ans[qid]  # rewritten in place
            if rtune:
                if not d["answer"].endswith("."):
                    d["answer"] += "."
                d["answer"] += " I am not sure."
            else:
                d["answer"] = "I don't know."
            output_data.append(d)
        with File.open(output_file, "w") as f:
            for d in output_data:
                f.write(json.dumps(d) + "\n")


def merge_rtune_files(folder="/models/LLaVA/llava-v1.5-13b/rtune/answers/"):
    """Merge the four ``short.{3,2,1,0}.eval`` shards and convert them for training.

    For both the ``rtune`` and ``idk`` variants, the shard files are
    concatenated into ``llava_v1_5_mix665k.shuffle.short.eval/merge.<kind>.jsonl``
    under ``folder``; afterwards each merged file is converted to
    ``merge.<kind>.train.json`` in LLaVA training format.

    NOTE: this module defines another ``merge_rtune_files`` further down; that
    later definition rebinds the name, so this version is not the one callers
    get from the module.
    """
    merged_dir = f"{folder}/llava_v1_5_mix665k.shuffle.short.eval"
    kinds = ("rtune", "idk")

    for kind in kinds:
        # Shard 3 is named "merge.<kind>.jsonl"; shards 2..0 carry "incomplete".
        shard_paths = [
            f"{folder}/llava_v1_5_mix665k.shuffle.short.3.eval/merge.{kind}.jsonl",
            f"{folder}/llava_v1_5_mix665k.shuffle.short.2.eval/merge.incomplete.{kind}.jsonl",
            f"{folder}/llava_v1_5_mix665k.shuffle.short.1.eval/merge.incomplete.{kind}.jsonl",
            f"{folder}/llava_v1_5_mix665k.shuffle.short.0.eval/merge.incomplete.{kind}.jsonl",
        ]
        with File.open(f"{merged_dir}/merge.{kind}.jsonl", "w") as merged:
            for shard in shard_paths:
                with File.open(shard, "r") as part:
                    for row in part.readlines():
                        merged.write(row)

    # Conversion runs only after both merged files have been written.
    for kind in kinds:
        convert_llava_eval_format_to_train_format(
            f"{merged_dir}/merge.{kind}.jsonl",
            f"{merged_dir}/merge.{kind}.train.json",
        )


def merge_gqa_idk_rtune_train_files(folder="models/LLaVA/llava-v1.5-13b/rtune/answers/gqa/", rtune_type="rtune", num_splits=11):
    """Concatenate the per-split GQA training JSON lists into a single file.

    Args:
        folder: Folder containing the ``train_lama_box_q.llava.5k.<split>.eval``
            sub-folders; the merged file is written directly into it.
        rtune_type: Variant tag embedded in the file names (e.g. ``"rtune"``).
        num_splits: Number of splits to merge, numbered from 0.
    """
    to_merge = [
        f"{folder}/train_lama_box_q.llava.5k.{split}.eval/models_Mistral_Mistral-7B-Instruct-v0.2.{rtune_type}.train.json"
        for split in range(num_splits)
    ]
    # Read every input before touching the output, so a missing or corrupt
    # split does not leave a freshly truncated output file behind.
    all_data = []
    for file in to_merge:
        with File.open(file, "r") as f:
            all_data += json.load(f)
    with File.open(f"{folder}/models_Mistral_Mistral-7B-Instruct-v0.2.{rtune_type}.train.json", "w") as out:
        json.dump(all_data, out)


def merge_docci_idk_rtune_train_files(folder="models/LLaVA/llava-v1.5-7b/rtune/answers/docci/", rtune_type="rtune", num_splits=15):
    """Concatenate the per-split DOCCI training JSON lists into a single file.

    Args:
        folder: Folder containing the ``ours_caption_based.<split>.eval``
            sub-folders; the merged file is written directly into it.
        rtune_type: Variant tag embedded in the file names (e.g. ``"rtune"``).
        num_splits: Number of splits to merge, numbered from 0.
    """
    to_merge = [
        f"{folder}/ours_caption_based.{split}.eval/models_Mistral_Mistral-7B-Instruct-v0.2.{rtune_type}.train.json"
        for split in range(num_splits)
    ]
    # Read every input before touching the output, so a missing or corrupt
    # split does not leave a freshly truncated output file behind.
    all_data = []
    for file in to_merge:
        with File.open(file, "r") as f:
            all_data += json.load(f)
    with File.open(f"{folder}/models_Mistral_Mistral-7B-Instruct-v0.2.{rtune_type}.train.json", "w") as out:
        json.dump(all_data, out)


def merge_rtune_train_files(folder="models/LLaVA/llava-v1.5-7b/rtune/", rtune_type="rtune"):
    """Write three merged training JSON files under ``folder``.

    Outputs (each named ``merge.<rtune_type>.train.json``):
      * ``unk_v1+gqa_idk+docci_idk``: docci + gqa + ours_unk_vqa_train
      * ``unk_v1+gqa_idk``: gqa + ours_unk_vqa_train
      * ``unk_v1+gqa_idk+docci_idk+llava_data``: docci + ``ours+llava+gqa_idk``

    In the first two merges, records read from a path containing
    ``ours_unk_vqa_train`` get ``lama-gpt4v_gen_q/`` prepended to ``image``
    unless that name already occurs in it.

    The ``ours+llava+gqa_idk`` input of the last merge is not produced by this
    function; it must already exist.
    """
    image_root = "lama-gpt4v_gen_q"
    mistral_name = f"models_Mistral_Mistral-7B-Instruct-v0.2.{rtune_type}.train.json"
    docci_rel = f"answers/docci/{mistral_name}"
    gqa_rel = f"answers/gqa/{mistral_name}"
    unk_rel = f"ours_unk_vqa_train/{mistral_name}"

    # (output sub-folder, relative inputs in merge order, apply image prefixing?)
    jobs = [
        ("unk_v1+gqa_idk+docci_idk", [docci_rel, gqa_rel, unk_rel], True),
        ("unk_v1+gqa_idk", [gqa_rel, unk_rel], True),
        ("unk_v1+gqa_idk+docci_idk+llava_data",
         [docci_rel, f"ours+llava+gqa_idk/merge.{rtune_type}.train.json"], False),
    ]

    for out_dir, rel_inputs, prefix_unk_images in jobs:
        sources = [f"{folder}/{rel}" for rel in rel_inputs]
        with File.open(f"{folder}/{out_dir}/merge.{rtune_type}.train.json", "w") as out:
            merged = []
            for src in sources:
                with File.open(src, "r") as f:
                    records = json.load(f)
                    if prefix_unk_images and "ours_unk_vqa_train" in src:
                        for rec in records:
                            if image_root not in rec["image"]:
                                rec["image"] = os.path.join(image_root, rec["image"])
                            merged.append(rec)
                    else:
                        merged += records
            json.dump(merged, out)


def construct_rtune_dataset_for_qwenvl(folder="/models/Qwen/Qwen-VL-Chat/rtune/answers"):
    """Build, merge and convert the 16-split rtune/idk datasets for Qwen-VL.

    Steps, in order:
      1. For each split ``0..15`` run ``construct_rtune_dataset`` twice, once
         for the ``rtune`` output and once for the ``idk`` output.
      2. Concatenate the per-split ``merge.rtune.jsonl`` files, then the
         per-split ``merge.idk.jsonl`` files, into
         ``llava_v1_5_mix665k.shuffle.short.eval`` under ``folder``.
      3. Convert both merged files to Qwen-VL training JSON.
    """
    data_folder = "<DATA_FOLDER>"
    stem = "llava_v1_5_mix665k.shuffle.short"
    split_ids = range(16)

    for i in split_ids:
        split_dir = f"{folder}/{stem}.5k.{i}.eval"
        pred_file = f"{split_dir}/merge.jsonl"
        gt_file = f"{data_folder}/{stem}.5k.{i}.eval.jsonl"
        rtune_out = f"{split_dir}/merge.rtune.jsonl"
        construct_rtune_dataset(pred_file, gt_file, rtune_out, rtune=True)
        # The idk path is derived by substitution on the whole rtune path.
        construct_rtune_dataset(pred_file, gt_file, rtune_out.replace(".rtune.", ".idk."), rtune=False)

    merged_dir = f"{folder}/{stem}.eval"
    kinds = ("rtune", "idk")
    for kind in kinds:
        with File.open(f"{merged_dir}/merge.{kind}.jsonl", "w") as merged:
            for i in split_ids:
                with File.open(f"{folder}/{stem}.5k.{i}.eval/merge.{kind}.jsonl", "r") as part:
                    for row in part.readlines():
                        merged.write(row)

    for kind in kinds:
        convert_llava_eval_format_to_qwenvl_train_format(
            f"{merged_dir}/merge.{kind}.jsonl",
            f"{merged_dir}/merge.{kind}.train.json",
            image_folder=data_folder,
        )


def is_question_answerable(data, assume_answerable=True):
    """Decide whether a question record is answerable.

    The first applicable rule wins:
      1. ``question_id`` contains ``"remove_0"`` (hard-coded marker for our
         unk questions) -> ``0``.
      2. ``answerable`` present -> its value, returned unchanged.
      3. ``category`` present -> ``0`` if it is ``"unk"``, else ``1``.
      4. ``answer_type`` present -> ``0`` if ``"unanswerable"``, else ``1``.
      5. ``question_type`` present -> ``0`` if ``"adversarial"`` or
         ``"absurd"``, else ``1``.
      6. Otherwise ``1`` if ``assume_answerable`` else ``None``.

    Args:
        data: Record dict; must contain ``question_id``.
        assume_answerable: Fallback used when no rule applies.

    Returns:
        ``0``/``1`` (or the raw ``answerable`` value), or ``None`` when nothing
        applies and ``assume_answerable`` is false.
    """
    # str() so numeric question ids do not raise TypeError on the substring test.
    if "remove_0" in str(data["question_id"]):
        return 0
    if "answerable" in data:
        return data["answerable"]
    if "category" in data:
        return 0 if data["category"] == "unk" else 1
    if "answer_type" in data:
        return 0 if data["answer_type"] == "unanswerable" else 1
    if "question_type" in data:
        return 0 if data["question_type"] in ("adversarial", "absurd") else 1
    return 1 if assume_answerable else None


def construct_rtune_dataset_based_on_lave_metrics(pred_file, gt_file, eval_model_id, rtune_idk=False, image_folder=None):
    '''
    Rewrite ground-truth answers using saved accuracy/refusal judgments.

    Reads four line-aligned JSONL files: ``pred_file``, ``gt_file`` and, next
    to ``pred_file``, ``<model>_lave_output.jsonl`` and
    ``<model>_refusal_lave_output.jsonl`` (``<model>`` is ``eval_model_id``
    with ``/`` replaced by ``_``). Iteration stops at the shortest file.
    Rows whose prediction prompt differs from the judged question are skipped.

    An answer is rewritten when
      * the question should be refused (``gt_refusal == 1``) and the model
        either did not refuse (``answer_refusal == 0``) or has ``acc == 0``;
      * the question should be answered and the model either refused
        (``answer_refusal > 0``) or has ``acc == 0``.
    The rewritten answer is ``"I don't know."`` when ``rtune_idk`` is set,
    otherwise the original answer followed by ``" I am not sure."``.

    The result is written as ``<model>.idk.jsonl`` / ``<model>.rtune.jsonl``
    beside ``pred_file`` and converted to a ``.train.json`` file (LLaVA format
    when ``"LLaVA"`` occurs in ``pred_file``, Qwen-VL format otherwise, which
    requires ``image_folder``).

    Usage:
    SPLIT=0 && python -m llava.eval.utils construct_rtune_dataset_based_on_lave_metrics --pred_file /models/LLaVA/llava-v1.5-7b/rtune/answers/gqa/train_lama_box_q.llava.5k.${SPLIT}.eval/merge.jsonl --gt_file  /<DATA_FOLDER>gqa/train_lama_box_q.llava.5k.${SPLIT}.eval.jsonl --eval_model_id models/Mistral/Mistral-7B-Instruct-v0.2
    '''
    from tqdm import tqdm
    import os

    def _read_jsonl(path):
        # Load a JSONL file into a list of dicts and close the handle.
        with File.open(path, "r") as f:
            return [json.loads(line.strip()) for line in f.readlines()]

    pred_dir = os.path.dirname(pred_file)
    model_tag = eval_model_id.replace('/', '_')
    preds = _read_jsonl(pred_file)
    gts = _read_jsonl(gt_file)
    acc_output = _read_jsonl(os.path.join(pred_dir, f"{model_tag}_lave_output.jsonl"))
    recall_output = _read_jsonl(os.path.join(pred_dir, f"{model_tag}_refusal_lave_output.jsonl"))
    output_file = os.path.join(pred_dir, f"{model_tag}.idk.jsonl" if rtune_idk else f"{model_tag}.rtune.jsonl")

    num_converted_refusal = 0
    num_converted_answer = 0
    outputs = []
    for pred_d, gt_d, acc_d, recall_d in tqdm(zip(preds, gts, acc_output, recall_output)):
        # The two judgment files must describe the same sample.
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"

        assert pred_d["question_id"] == recall_d["question_id"], f"Question id mismatch {pred_d['question_id']} vs {recall_d['question_id']}"
        # Rows whose prompt differs from the judged question are dropped, so
        # no further prompt assertion is needed below.
        if "question" not in recall_d or pred_d["prompt"] != recall_d["question"]:
            continue
        assert gt_d["question_id"] == recall_d["question_id"], f"Question id mismatch {gt_d['question_id']} vs {recall_d['question_id']}"
        # The message reports gt_d['text'], the field actually being compared.
        assert gt_d["text"].replace("<image>\n", "") == recall_d["question"], f"Question mismatch {gt_d['text']} vs {recall_d['question']}"

        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # Answerability derived from the record overrides the stored flag.
            gt_refusal = not answerable

        if gt_refusal == 1:
            convert_answer = recall_d["answer_refusal"] == 0 or acc_d["acc"] == 0
            if convert_answer:
                num_converted_refusal += 1
        else:
            convert_answer = recall_d["answer_refusal"] > 0 or acc_d["acc"] == 0
            if convert_answer:
                num_converted_answer += 1

        if convert_answer:
            if rtune_idk:
                gt_d["answer"] = "I don't know."
            else:
                gt_d["answer"] = gt_d["answer"] + " I am not sure."
        outputs.append(gt_d)

    print(f"Total: {len(outputs)}")
    print(f"Converted refusal: {num_converted_refusal}")
    print(f"Converted answer: {num_converted_answer}")
    with File.open(output_file, "w") as f:
        for d in outputs:
            f.write(json.dumps(d) + "\n")
    if "LLaVA" in pred_file:
        convert_llava_eval_format_to_train_format(output_file, output_file.replace(".jsonl", ".train.json"))
    else:
        assert image_folder is not None, "image_folder is required for Qwen-VL"
        convert_llava_eval_format_to_qwenvl_train_format(output_file, output_file.replace(".jsonl", ".train.json"), image_folder=image_folder)



def convert_llava_eval_format_to_train_format(josnl, output_json):
    """Convert an eval-format JSONL file into a LLaVA training JSON file.

    Every input record must provide ``question_id``, ``text``, ``answer`` and
    ``image``. The output is a single JSON list whose entries look like::

        {
            "id": "000000157875",
            "image": "000000157875.jpg",
            "conversations": [
                {"from": "human", "value": "<image> + newline + question"},
                {"from": "gpt", "value": "answer"}
            ]
        }

    Args:
        josnl: Path of the input JSONL file.
        output_json: Path of the JSON file to write.
    """
    # Close the input handle deterministically.
    with File.open(josnl, "r") as f:
        data = [json.loads(s.strip()) for s in f.readlines()]

    # NOTE(review): the "<image>" tag is prepended unconditionally; confirm
    # that input prompts never already contain it.
    output_data = [
        {
            "id": d["question_id"],
            "image": d["image"],
            "conversations": [
                {"from": "human", "value": "<image>\n" + d["text"]},
                {"from": "gpt", "value": d["answer"]},
            ],
        }
        for d in data
    ]

    with File.open(output_json, "w") as f:
        json.dump(output_data, f)


def convert_llava_eval_format_to_qwenvl_train_format(josnl, output_json, image_folder):
    """Convert an eval-format JSONL file into a Qwen-VL training JSON file.

    Every input record must provide ``question_id``, ``text``, ``answer`` and
    ``image``. Each output entry has an ``id`` and a two-turn conversation:
    the ``user`` turn is ``Picture 1: <img>PATH</img>`` followed by a newline
    and the prompt, where ``PATH`` is ``image`` joined onto ``image_folder``;
    the ``assistant`` turn is the answer.

    Args:
        josnl: Path of the input JSONL file.
        output_json: Path of the JSON file to write.
        image_folder: Folder that image paths are resolved against.
    """
    import os
    # Close the input handle deterministically.
    with File.open(josnl, "r") as f:
        data = [json.loads(s.strip()) for s in f.readlines()]

    output_data = []
    for d in data:
        image_path = os.path.join(image_folder, d["image"])
        output_data.append({
            "id": d["question_id"],
            "conversations": [
                {"from": "user", "value": f"Picture 1: <img>{image_path}</img>\n{d['text']}"},
                {"from": "assistant", "value": d["answer"]},
            ],
        })

    with File.open(output_json, "w") as f:
        json.dump(output_data, f)
    


def check_output_file_exists(output_file, verbose=False):
    """Locate ``output_file``, falling back to its ``answers/`` sub-folder.

    Returns ``output_file`` if it exists; otherwise the same file name inside
    an ``answers`` folder next to it, if that exists; otherwise ``""``.
    With ``verbose`` the returned value is also printed.
    """
    import os
    if File.isfile(output_file):
        found = output_file
    else:
        fallback = f"{os.path.dirname(output_file)}/answers/{os.path.basename(output_file)}"
        found = fallback if File.isfile(fallback) else ""
    if verbose:
        print(found)
    return found

def copy_file(src, dest):
    """Copy the text content of ``src`` to ``dest`` through ``File``.

    The whole source is read into memory before it is written out.
    """
    # Use a context manager so the source handle is closed as well.
    with File.open(src, "r") as f:
        content = f.read()
    with File.open(dest, "w") as f:
        f.write(content)


def merge_rtune_files():
    """Merge the LLaVA-data and ours_unk_vqa_train training sets for llava-v1.5-7b.

    For each of the ``rtune`` and ``idk`` variants, the two training JSON
    lists are concatenated into ``ours+llava/merge.<variant>.train.json``.
    Records from the ours_unk_vqa_train file get ``lama-gpt4v_gen_q/``
    prepended to ``image`` unless that name already occurs in it. The written
    file is then read back and the signed word counts of its first ten
    samples are printed.

    NOTE: this definition replaces an earlier function of the same name
    defined higher up in this module.
    """
    for split in ["rtune", "idk"]:
        # llava_data_rtune = f"/models/Qwen/Qwen-VL-Chat/rtune/answers/llava_v1_5_mix665k.shuffle.short.eval/merge.{split}.train.json"
        # our_data_rtune = f"/models/Qwen/Qwen-VL-Chat/rtune/ours_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0.2.{split}.train.json"
        # output_file = f"/models/Qwen/Qwen-VL-Chat/rtune/ours+llava/merge.{split}.train.json"
        # data = []
        # for file in [llava_data_rtune, our_data_rtune]:
        #     data += json.load(File.open(file, "r"))
        # with File.open(output_file, "w") as f:
        #     json.dump(data, f)

        llava_data_rtune = f"/models/LLaVA/llava-v1.5-7b/rtune/answers/llava_v1_5_mix665k.shuffle.short.eval//merge.{split}.train.json"
        our_data_rtune = f"/models/LLaVA/llava-v1.5-7b/rtune/ours_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0.2.{split}.train.json"
        output_file = f"/models/LLaVA/llava-v1.5-7b/rtune/ours+llava/merge.{split}.train.json"

        # Every handle is opened in a ``with`` block so it is closed promptly.
        data = []
        with File.open(llava_data_rtune, "r") as f:
            data += json.load(f)
        with File.open(our_data_rtune, "r") as f:
            our_data = json.load(f)
        for d in our_data:
            if "lama-gpt4v_gen_q" not in d["image"]:
                d["image"] = os.path.join("lama-gpt4v_gen_q", d["image"])
            data.append(d)
        with File.open(output_file, "w") as f:
            json.dump(data, f)

        # Sanity check: reload what was written and print per-sample word
        # counts (negated for samples without an "image" key).
        with File.open(output_file, "r") as f:
            output_data = json.load(f)
        length_list = []
        for sample in output_data:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        print(length_list[:10])


def format_recall_print(recall):
    """Format refusal/answer recall numbers as one comma-separated row.

    ``recall`` must hold the fractions ``refusal``, ``answer``,
    ``positive_refusal``, ``false_refusal``, ``positive_answer`` and
    ``false_answer``. All columns are percentages with two decimals, in this
    order: refusal, positive refusal (dataset level), positive refusal,
    false refusal (dataset level), answer, positive answer (dataset level),
    positive answer, false answer (dataset level), total positive.
    "Dataset level" means the conditional rate multiplied by the rate of its
    group (refusal or answer).
    """
    refusal_frac = recall["refusal"]
    answer_frac = recall["answer"]
    pos_refusal_frac = recall["positive_refusal"]
    false_refusal_frac = recall["false_refusal"]
    pos_answer_frac = recall["positive_answer"]
    false_answer_frac = recall["false_answer"]

    pos_refusal_overall = pos_refusal_frac * refusal_frac * 100
    pos_answer_overall = pos_answer_frac * answer_frac * 100

    columns = (
        refusal_frac * 100,
        pos_refusal_overall,
        pos_refusal_frac * 100,
        false_refusal_frac * refusal_frac * 100,
        answer_frac * 100,
        pos_answer_overall,
        pos_answer_frac * 100,
        false_answer_frac * answer_frac * 100,
        pos_refusal_overall + pos_answer_overall,
    )
    return ",".join(f"{value:.2f}" for value in columns)




def format_recall_print_overall(recall):
    """Return ``recall["positive_refusal"]`` as a percentage with two decimals."""
    percent = recall["positive_refusal"] * 100
    return f"{percent:.2f}"


def format_f1_print_refusal(recall):
    """Return the F1 score of refusal prediction as a percentage string.

    Uses ``recall["counts"]`` with the keys ``pred_refusal``,
    ``positive_refusal``, ``pred_answer`` and ``positive_answer``:
    true positives are the positive refusals, false positives the remaining
    predicted refusals, and false negatives the predicted answers that are
    not positive answers.

    Precision, recall and F1 are each reported as 0 when their denominator is
    zero, so inputs without any predicted or correct refusals do not raise.

    Returns:
        The F1 score times 100, formatted with two decimals.
    """
    counts = recall["counts"]
    true_pos = counts["positive_refusal"]
    false_pos = counts["pred_refusal"] - true_pos
    false_neg = counts["pred_answer"] - counts["positive_answer"]

    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg
    precision = true_pos / predicted_pos if predicted_pos else 0.0
    # Named differently from the ``recall`` argument to avoid shadowing it.
    recall_score = true_pos / actual_pos if actual_pos else 0.0

    denom = precision + recall_score
    f1 = 2 * precision * recall_score / denom if denom else 0.0
    return f"{f1*100:.2f}"


def format_calibration_print(calibration_score):
    """Format a calibration result with four decimals.

    The first key found decides the output: ``brier_score``, then
    ``brier_score_bins`` (each printed alone), then ``score_ece`` (printed
    together with ``score_max`` as ``"ece, max"``). Returns ``None`` when none
    of these keys is present.
    """
    for single_key in ("brier_score", "brier_score_bins"):
        if single_key in calibration_score:
            return f"{calibration_score[single_key]:.4f}"
    if "score_ece" in calibration_score:
        ece = calibration_score["score_ece"]
        worst = calibration_score["score_max"]
        return f"{ece:.4f}, {worst:.4f}"
    return None


def format_calibration_print_overall(calibration_score):
    """Format the single headline calibration number with four decimals.

    The first key found is used: ``brier_score``, ``brier_score_bins``, then
    ``score_ece``. Returns ``None`` when none of them is present.

    ``score_max`` is not part of the output and is not required to be present.
    """
    for key in ("brier_score", "brier_score_bins", "score_ece"):
        if key in calibration_score:
            return f"{calibration_score[key]:.4f}"
    return None


def format_acc_print(acc, gt_not_yes_or_no=False, include_risk_coverage=False):
    """Format accuracy numbers as one comma-separated row of percentages.

    Columns, each with two decimals: refusal, answer, all; then, if
    ``include_risk_coverage``, risk, coverage and risk*coverage (as a
    percentage); then, if ``gt_not_yes_or_no``, the ``gt_not_yes_or_no``
    share, or the ``pred_not_yes_or_no`` share when the former is unavailable.

    Args:
        acc: Dict of fractions with at least ``all``, ``refusal``, ``answer``.
        gt_not_yes_or_no: Append the not-yes-or-no share.
        include_risk_coverage: Append the risk/coverage columns.
    """
    all_acc = acc["all"]*100
    refusal_acc = acc["refusal"]*100
    answer_acc = acc["answer"]*100
    values = [f"{refusal_acc:.2f}", f"{answer_acc:.2f}", f"{all_acc:.2f}"]

    if include_risk_coverage:
        risk = acc["risk"]*100
        coverage = acc["coverage"]*100
        risk_coverage = risk*coverage / 100
        values.extend([f"{risk:.2f}", f"{coverage:.2f}", f"{risk_coverage:.2f}"])

    if gt_not_yes_or_no:
        try:
            not_yes_or_no = acc["gt_not_yes_or_no"]*100
        except (KeyError, TypeError):
            # Key absent (or not numeric): fall back to the prediction-side
            # share. Only these errors are caught so that unrelated failures
            # and interrupts still propagate.
            not_yes_or_no = acc["pred_not_yes_or_no"]*100
        values.append(f"{not_yes_or_no:.2f}")

    return ",".join(values)

def format_acc_print_overall(acc):
    """Return ``acc["all"]`` as a percentage with two decimals."""
    overall_percent = acc["all"] * 100
    return f"{overall_percent:.2f}"

def check_debug_file_exists(output_file):
    """Return ``output_file`` if it exists, otherwise an empty string.

    No alternative location is probed; a fallback to a ``.debug``-suffixed
    sibling file is currently disabled.
    """
    return output_file if File.isfile(output_file) else ""

def gather_results(eval_model_id="models/Mistral/Mistral-7B-Instruct-v0.2", overwrite=False, clear_cache=False, calibration_only=False):
    """Print one comma-separated row of LAVE metrics per (dataset, model) pair.

    For every dataset and every entry of the hard-coded ``model_paths`` list,
    the function looks for ``/{model_path}/{dataset}/merge.jsonl`` and for the
    LAVE files next to it, whose names start with ``eval_model_id`` with every
    ``/`` replaced by ``_``.

    Each metric group (overall, ground-truth-probability confidence weighted,
    predicted-probability confidence weighted, calibration) follows the same
    shape: when its result files are missing (or ``overwrite`` is set, where
    the condition includes it) the results are computed through helpers from
    ``llava.eval.lave_metric``; otherwise the existing JSON files are loaded
    and appended to the row. Computing and printing sit in opposite branches,
    so results computed during a call are not printed by that same call.

    A marker line ``>>>>>>>>>>>>>>>>>>>>>>{model_path}`` is printed instead of
    a row when the merged output or the refusal result file is missing, or
    when nothing was collected for the row.

    Args:
        eval_model_id: Identifier of the evaluator model; used only to build
            file names.
        overwrite: Recompute the overall, confidence-weighted and answerable
            calibration results even if their files already exist. The
            "all" and "unk" calibration checks do not consult this flag.
        clear_cache: Call ``File.clear_cache`` on the result and input files
            before they are checked.
        calibration_only: Leave the recall / accuracy / confidence-weighted
            columns out of the row so that only calibration scores remain.

    Returns:
        None. All output goes to stdout.
    """
    # NOTE(review): short entries such as "qwen-sft" or "dummy" presumably act
    # as section separators -- no output is expected under them, so only the
    # marker line gets printed. Confirm with the author.
    model_paths =[
        # # "models/Qwen/Qwen-VL-Chat",
        # # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-_ep1/",
        # # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep1",
        # #   "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep5",
        # # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        "models/Qwen/Qwen-VL-Chat",
        "qwen-sft",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_lrv_with_chart_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk_v1+gqa_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_docci_idk-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
        "qwen-rtune",
        "<OUTPUT_FOLDER>/qwen-vl/rtune-qwen-vl-chat-lora-finetune_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_idk-qwen-vl-chat-lora-finetune_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_ous_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0_2-rtune-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_ous_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0_2-idk-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_ours+llava/rtune-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_ours+llava/idk-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk/rtune-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk/idk-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_docci_idk/rtune-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_docci_idk/idk-qwen-vl-chat-lora-finetune_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk/rtune-qwen-vl-chat-lora-finetune_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk/idk-qwen-vl-chat-lora-finetune_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk+llava_data/rtune-qwen-vl-chat-lora-finetune_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk+llava_data/idk-qwen-vl-chat-lora-finetune_ep1",
        "qwen-dpo",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_idk_only_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_unk_v1+gqa_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+unk_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+unk+gqa_idk_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_silkie_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_idk+silkie_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_unk_v1+gqa_unk+silkie_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_docci_idk_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+docci_train_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+ours_image_based+docci_train_ep1",
        "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+ours_caption_based+silkie_train_ep1",
        "dummy",
        # "models/LLaVA/llava-v1.5-7b",
        #   "<OUTPUT_FOLDER>/llava/sft_llava-v1.5-7b-task-lora",
        #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora",
        #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep1",
        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #   "<OUTPUT_FOLDER>/llava/sft_debug_llava-v1.5-7b-task",
        #   "<OUTPUT_FOLDER>/llava/sft_llava-v1.5-7b-task-lora",

        "models/LLaVA/llava-v1.5-7b",
        "llava-sft",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_lrv_with_chart/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_llava_data",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+llava_data+perturb_answerable",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+llava_data+perturb_answerable+gqa_idk",
        "<OUTPUT_FOLDER>/llava/sft_docci_idk_debug_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task+perturb_answerable_ep2",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task+perturb_answerable_ep3",
        "llava_rtune",
        "<OUTPUT_FOLDER>/llava/rtune_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/rtune_idk_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/rtune_ours_unk_vqa/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/rtune_ours_unk_vqa/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/rtune_ours+llava/rtune_llava-v1.5-7b-task-lora",
        "<OUTPUT_FOLDER>/llava/rtune_ours+llava/idk_llava-v1.5-7b-task-lora",
        # "<OUTPUT_FOLDER>/llava/rtune_ours+llava+gqa_idk/rtune_llava-v1.5-7b-task-lora/",
        # "<OUTPUT_FOLDER>/llava/rtune_ours+llava+gqa_idk/idk_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk+llava_data/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
        "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk+llava_data/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep3",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep2",
        # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep1",# 
        "llava_dpo",
        "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk_only_lr2e-6_ep1_mmlr0_beta0.1",
        "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk_v1+gqa_idk_lr2e-6_ep3_mmlr0",
        "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_pope+desc_data_ep1/",
        "<OUTPUT_FOLDER>/llava/dpo_lora_pope+desc_data+unk_lr2e-6_ep1_mmlr0",
        "<OUTPUT_FOLDER>/llava/dpo_lora_pope+desc_data+unk+gqa_idk_ep1_mmlr0",
        "<OUTPUT_FOLDER>/llava/dpo_lora_silkie_lr2e-6_ep1_mmlr0",
        "<OUTPUT_FOLDER>/llava/dpo_lora_silkie+unk_lr2e-6_ep1_mmlr0_bsz2",
        "<OUTPUT_FOLDER>/llava/dpo_lora_silkie+unk+gqa_idk_ep1_mmlr0/",
        "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_docci_idk_only_lr2e-6_ep1_mmlr0_beta0.1/",
        "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk+gqa_idk+docci_train_lr2e-6_ep1_mmlr0/",
        "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+pope+desc_data_ep1_mmlr0/",
        "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+silkie_ep1_mmlr0",
        "llava_next",
        "models/LLaVA/llava-v1.5-13b",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
        "llava-v1.5-13b/sft-all",
        "models/LLaVA/llava-v1.6-vicuna-7b/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-7b-task-lora_unk_v1+gqa+ours_caption_based/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
        "models/LLaVA/llava-v1.6-vicuna-13b",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based/",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
        "models/LLaVA/llava-v1.6-34b",
        "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-34b-task-lora_unk_v1+gqa+ours_caption_based/",
        "llava-v1.6-34b/sft-all",
    ]
    datasets = [
        # # "unk_vqa",
        # # "unk_vqa_idk",
        # "unk_vqa_perturbed",
        "unk_vqa_validated",
        # "vizwiz_val",
        "vizwiz_val_prompt",
        "MMHal-Bench",
        "vqa_k_test_minus_llava_5k_prompt",
        "tdiuc_absurd_val_absurd_k_minus_llava_5k_prompt",
        "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        # # "gqa_single_object_idk",
        "gqa_idk",
        "docci_know_test.5k",
        "docci_pred_test.5k",
        "docci_ambiguity_test.5k",
        "docci_complex_test.5k"
    ]
    # One banner line per dataset, then one row (or marker line) per model.
    for dataset in datasets:
        print(f"===================={dataset}====================")
        for model_path in model_paths:
            output_file = f"/{model_path}/{dataset}/merge.jsonl"
            # check_output_file_exists is not defined in this part of the file;
            # an empty result is treated below as "no merged output".
            output_file = check_output_file_exists(output_file, verbose=False)
            if len(output_file) == 0:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            # get lave result files
            # Each of the four paths below is "" when the file does not exist.
            dir_name = os.path.dirname(output_file)
            lave_accuray_output = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_output.jsonl")
            lave_recall_output = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_refusal_lave_output.jsonl")
            lave_acc_res_file = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_result.json")
            lave_recall_res_file = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_refusal_lave_result.json")

            overall_res_file = f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_overall_result.json"
            overall_refusal_res_file = overall_res_file.replace("overall_result", "overall_refusal_result")
            # load recall res
            # print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            print_str = ""
            # Without a refusal result file the model is skipped entirely.
            if len(lave_recall_res_file) == 0:
            #     recall = json.loads(File.open(lave_recall_res_file, "r").read())
            #     print_str += format_recall_print(recall)+","
            # else:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            # NOTE(review): lave_accuray_output / lave_recall_output may be ""
            # here (missing file); confirm File.clear_cache tolerates that.
            if clear_cache:
                File.clear_cache(overall_res_file)
                File.clear_cache(overall_refusal_res_file)
                File.clear_cache(lave_accuray_output)
                File.clear_cache(lave_recall_output)
            # --- Overall LAVE metrics: compute when missing, otherwise load and print.
            if overwrite or not File.isfile(overall_res_file) or not File.isfile(overall_refusal_res_file):
                from llava.eval.lave_metric import get_overall_lave_metrics, get_overall_lave_refusal_metrics
                if File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                    get_overall_lave_metrics(lave_accuray_output, lave_recall_output, overall_res_file)
                    get_overall_lave_refusal_metrics(lave_accuray_output, lave_recall_output, overall_refusal_res_file)
            else:
                overall_refusal = json.loads(File.open(overall_refusal_res_file, "r").read())
                if not calibration_only:
                    print_str += format_recall_print(overall_refusal)+","

                overall = json.loads(File.open(overall_res_file, "r").read())
                if not calibration_only:
                    print_str += format_acc_print(overall)+","



            # --- Confidence-weighted metrics based on the ground-truth probability file.
            conf_weighted_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_probyn_lave_conf_weighted_result.json")

            conf_weighted_refusal_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_probyn_lave_conf_weighted_reward_refusal_result.json")

            if clear_cache:
                File.clear_cache(conf_weighted_results_file)
                File.clear_cache(conf_weighted_refusal_results_file)
            if overwrite or not File.isfile(conf_weighted_results_file) or not File.isfile(conf_weighted_refusal_results_file):
                # Two directory spellings are tried: "gt_prob_..." then "gt_probs_...".
                gt_prob_file = f"/{model_path}/{dataset}/gt_prob_only_yes_or_no/merge.jsonl"
                if not File.isfile(gt_prob_file):
                    gt_prob_file = f"/{model_path}/{dataset}/gt_probs_only_yes_or_no/merge.jsonl"
                if clear_cache:
                    File.clear_cache(gt_prob_file)
                # assert File.isfile(gt_prob_file), f"{gt_prob_file} does not exist"
                if File.isfile(gt_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                    from llava.eval.lave_metric import get_confidence_weighted_lave_metrics
                    get_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, conf_weighted_results_file, refusal_reward=False, debug=False)

                    get_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, conf_weighted_refusal_results_file, refusal_reward=True, debug=False)
            else:
                conf_weighted = json.loads(File.open(conf_weighted_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = False: \n {format_acc_print(conf_weighted)}")

                if not calibration_only:
                    print_str += format_acc_print(conf_weighted)+","

                conf_weighted = json.loads(File.open(conf_weighted_refusal_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = True: \n {format_acc_print(conf_weighted)}")

                # The refusal-reward variant also prints the "not yes or no" column.
                if not calibration_only:
                    print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=False)+","

            # --- Confidence-weighted metrics based on the predicted probability file.
            # The two path variables are rebound to the "pred_probyn" file names.
            conf_weighted_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_pred_probyn_lave_conf_weighted_result.json")

            conf_weighted_refusal_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_pred_probyn_lave_conf_weighted_reward_refusal_result.json")

            if clear_cache:
                File.clear_cache(conf_weighted_results_file)
                File.clear_cache(conf_weighted_refusal_results_file)
            if overwrite or not File.isfile(conf_weighted_results_file) or not File.isfile(conf_weighted_refusal_results_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
                if not File.isfile(pred_prob_file):
                    pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"
                if clear_cache:
                    File.clear_cache(pred_prob_file)
                # assert File.isfile(gt_prob_file), f"{gt_prob_file} does not exist"
                if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                    from llava.eval.lave_metric import get_pred_confidence_weighted_lave_metrics
                    get_pred_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, conf_weighted_results_file, refusal_reward=False, debug=False)

                    get_pred_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, conf_weighted_refusal_results_file, refusal_reward=True, debug=False)
            else:
                conf_weighted = json.loads(File.open(conf_weighted_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = False: \n {format_acc_print(conf_weighted)}")

                if not calibration_only:
                    print_str += format_acc_print(conf_weighted)+","

                conf_weighted = json.loads(File.open(conf_weighted_refusal_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = True: \n {format_acc_print(conf_weighted)}")

                # Here the risk / coverage columns are printed as well.
                if not calibration_only:
                    print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=True)+","


            # --- Calibration on answerable questions. These columns are printed
            # regardless of calibration_only.
            from llava.eval.lave_metric import get_calibration_fig
            pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
            if not File.isfile(pred_prob_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"
            if clear_cache:
                File.clear_cache(pred_prob_file)
            calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_answerable_calibration_curve.png")
            calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_answerable_calibration_score.json")
            calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            if not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file) or overwrite:
                # The print below emits a bare True/False line (looks like a
                # leftover debug aid).
                # NOTE(review): unlike the branches above, get_calibration_fig
                # is called without checking that its input files exist.
                print(File.isfile(calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            else:
                calibration_score = json.loads(File.open(calibration_score_file, "r").read())
                print_str += format_calibration_print(calibration_score)+","
                calibration_score = json.loads(File.open(calibration_score_bins_file, "r").read())
                print_str += format_calibration_print(calibration_score)+","


            # --- Calibration on all questions.
            all_calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_all_calibration_curve.png")
            all_calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_all_calibration_score.json")
            all_calibration_score_bins_file = all_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            # NOTE(review): this check ignores `overwrite`, and the call below
            # passes the *answerable* curve path (calibration_results_file).
            # Presumably get_calibration_fig derives the "all" / "unk" outputs
            # from that path -- confirm, otherwise this is a copy-paste slip.
            if not File.isfile(all_calibration_results_file) or not File.isfile(all_calibration_score_file) or not File.isfile(all_calibration_score_bins_file):
                print(File.isfile(all_calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            else:
                all_calibration_score = json.loads(File.open(all_calibration_score_file, "r").read())
                print_str += format_calibration_print(all_calibration_score)+","
                all_calibration_score = json.loads(File.open(all_calibration_score_bins_file, "r").read())
                print_str += format_calibration_print(all_calibration_score)+","


            # --- Calibration on the "unk" subset; same caveats as the "all" block.
            unk_calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_unk_calibration_curve.png")
            unk_calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_unk_calibration_score.json")
            unk_calibration_score_bins_file = unk_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            if not File.isfile(unk_calibration_results_file) or not File.isfile(unk_calibration_score_file) or not File.isfile(unk_calibration_score_bins_file):
                print(File.isfile(unk_calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            else:
                unk_calibration_score = json.loads(File.open(unk_calibration_score_file, "r").read())
                print_str += format_calibration_print(unk_calibration_score)+","
                unk_calibration_score = json.loads(File.open(unk_calibration_score_bins_file, "r").read())
                # Last column of the row: no trailing comma.
                print_str += format_calibration_print(unk_calibration_score)

            # Print the collected row, or the marker line when nothing was loaded.
            if len(print_str):
                print(print_str)
            else:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")


def gather_results_refusal01(eval_model_id="models/Mistral/Mistral-7B-Instruct-v0.2", overwrite=False, clear_cache=False, calibration_only=False, full_results=False):
    model_paths =[
    # # # "models/Qwen/Qwen-VL-Chat",
    # # # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-_ep1/",
    # # # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep1",
    # # #   "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep5",
    # # # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    # "models/Qwen/Qwen-VL-Chat",
    # "qwen-sft",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_lrv_with_chart_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune-idk_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data-_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk_v1+gqa_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_docci_idk-_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
    # "qwen-rtune",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_idk-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_ous_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0_2-rtune-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_ous_unk_vqa_train/models_Mistral_Mistral-7B-Instruct-v0_2-idk-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_ours+llava/rtune-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_ours+llava/idk-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk/rtune-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk/idk-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_docci_idk/rtune-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_docci_idk/idk-qwen-vl-chat-lora-finetune_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk/rtune-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk/idk-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk+llava_data/rtune-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk+llava_data/idk-qwen-vl-chat-lora-finetune_ep1",
    # "qwen-dpo",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_idk_only_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk_v1+gqa_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+unk_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+unk+gqa_idk_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_silkie_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_idk+silkie_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk_v1+gqa_unk+silkie_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_docci_idk_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+docci_train_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+ours_image_based+docci_train_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+ours_caption_based+silkie_train_ep1",
    # "dummy",
    # # "models/LLaVA/llava-v1.5-7b",
    # #   "<OUTPUT_FOLDER>/llava/sft_llava-v1.5-7b-task-lora",
    # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora",
    # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep1",
    # # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    # #   "<OUTPUT_FOLDER>/llava/sft_debug_llava-v1.5-7b-task",
    # #   "<OUTPUT_FOLDER>/llava/sft_llava-v1.5-7b-task-lora",

    # "models/LLaVA/llava-v1.5-7b",
    # "llava-sft",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_lrv_with_chart/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_llava_data",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+llava_data+perturb_answerable",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+llava_data+perturb_answerable+gqa_idk",
    # "<OUTPUT_FOLDER>/llava/sft_docci_idk_debug_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task+perturb_answerable_ep2",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task+perturb_answerable_ep3",
    # "llava_rtune",
    # "<OUTPUT_FOLDER>/llava/rtune_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_idk_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_ours_unk_vqa/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_ours_unk_vqa/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_ours+llava/rtune_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_ours+llava/idk_llava-v1.5-7b-task-lora",
    # # "<OUTPUT_FOLDER>/llava/rtune_ours+llava+gqa_idk/rtune_llava-v1.5-7b-task-lora/",
    # # "<OUTPUT_FOLDER>/llava/rtune_ours+llava+gqa_idk/idk_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk+llava_data/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk+llava_data/models_Mistral_Mistral-7B-Instruct-v0.2-idk_llava-v1.5-7b-task-lora/",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep3",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep2",
    # # #   "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora+perturb_answerable_ep1",# 
    # "llava_dpo",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk_only_lr2e-6_ep1_mmlr0_beta0.1",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk_v1+gqa_idk_lr2e-6_ep3_mmlr0",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_pope+desc_data_ep1/",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_pope+desc_data+unk_lr2e-6_ep1_mmlr0",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_pope+desc_data+unk+gqa_idk_ep1_mmlr0",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_silkie_lr2e-6_ep1_mmlr0",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_silkie+unk_lr2e-6_ep1_mmlr0_bsz2",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_silkie+unk+gqa_idk_ep1_mmlr0/",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_docci_idk_only_lr2e-6_ep1_mmlr0_beta0.1/",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk+gqa_idk+docci_train_lr2e-6_ep1_mmlr0/",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+pope+desc_data_ep1_mmlr0/",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+silkie_ep1_mmlr0",
    # "llava_next",
    # "models/LLaVA/llava-v1.5-13b",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "llava-v1.5-13b/sft-all",
    # "models/LLaVA/llava-v1.6-vicuna-7b/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    # "models/LLaVA/llava-v1.6-vicuna-13b",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    # "models/LLaVA/llava-v1.6-34b",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-34b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "llava-v1.6-34b/sft-all",
    # "models/LLaVA/llava-v1.5-7b-lora",
    # "models/LLaVA/llava-v1.5-13b-lora",


    "models/Qwen/Qwen-VL-Chat",
    "qwen-sft",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_lrv_with_chart_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",

    # "qwen-rtune",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk/rtune-qwen-vl-chat-lora-finetune_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/rtune_unk_v1+gqa_idk+docci_idk+llava_data/rtune-qwen-vl-chat-lora-finetune_ep1",

    # "qwen-dpo",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_silkie_ep1",
    # # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo_ep1/",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+docci_train_ep1/",
    # # "<OUTPUT_FOLDER>/qwen-vl/dpo_hadpo+ours_image_based+docci_train_ep1",
    # "<OUTPUT_FOLDER>/qwen-vl/dpo_unk+gqa_idk+ours_caption_based+silkie_train_ep1",
    "dummy",

    # "models/LLaVA/llava-v1.5-7b",
    # "llava-sft",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_lrv_with_chart/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_llava_data",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k",
    # "llava_rtune",
    # "<OUTPUT_FOLDER>/llava/rtune_llava-v1.5-7b-task-lora",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk+llava_data/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "<OUTPUT_FOLDER>/llava/rtune_unk_v1+gqa_idk+docci_idk/models_Mistral_Mistral-7B-Instruct-v0.2-rtune_llava-v1.5-7b-task-lora/",
    # "llava-dpo",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_silkie_lr2e-6_ep1_mmlr0",
    # # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_pope+desc_data_ep1",
    # "<OUTPUT_FOLDER>/llava/dpo_llava-v1.5-7b_lora_unk+gqa_idk+docci_train_lr2e-6_ep1_mmlr0/",
    # "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+silkie_ep1_mmlr0/",
    # # "<OUTPUT_FOLDER>/llava/dpo_lora_unk+gqa_idk+ours_caption_based+pope+desc_data_ep1_mmlr0/",
    # "llava_next",
    # "models/LLaVA/llava-v1.5-13b",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "llava-v1.5-13b/sft-all",
    # "models/LLaVA/llava-v1.6-vicuna-7b/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    # "models/LLaVA/llava-v1.6-vicuna-13b",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    # "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    # "models/LLaVA/llava-v1.6-34b",
    # "llava-v1.6-34b/sft-unk",
    # "llava-v1.6-34b/sft-all",
    "llava-lora",
    "models/LLaVA/llava-v1.5-7b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k_new/",
    "models/LLaVA/llava-v1.5-13b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",

    # "<DATA_FOLDER>/gpt4v_output"

    ]
    datasets = [
        # # "unk_vqa",
        # # "unk_vqa_idk",
        # "unk_vqa_perturbed",
        # "unk_vqa_validated",
        # "vizwiz_val_short_prompt",
        # "vizwiz_val_prompt",
        # "MMHal-Bench",
        # "vqa_k_test_minus_llava_5k_prompt",
        # "tdiuc_absurd_val_absurd_k_minus_llava_5k_prompt",
        # "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        # # # "gqa_single_object_idk",
        # "gqa_idk",
        # "docci_know_test.5k",
        # "docci_pred_test.5k",
        # "docci_ambiguity_test.5k",
        # "docci_complex_test.5k"
        "tdiuc_absurd_val_absurd_k_minus_llava_5k",
        "unk_vqa_public_val_k_minus_llava_remove_ans_type2_no_prompt",
    ]
    for dataset in datasets:
        print(f"===================={dataset}====================")
        for model_path in model_paths:
            output_file = f"/{model_path}/{dataset}/merge.jsonl"
            output_file = check_output_file_exists(output_file, verbose=False)
            if len(output_file) == 0:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            # get lave result files
            dir_name = os.path.dirname(output_file)
            lave_accuray_output = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_output.jsonl")
            lave_recall_output = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_refusal01_lave_output.jsonl")
            lave_acc_res_file = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_result.json")
            lave_recall_res_file = check_debug_file_exists(f"{dir_name}/{eval_model_id.replace('/', '_')}_refusal01_lave_result.json")

            overall_res_file = f"{dir_name}/{eval_model_id.replace('/', '_')}_lave_overall_result_w_refusal_01.json"
            overall_refusal_res_file = overall_res_file.replace("overall_result", "overall_refusal_result")
            # load recall res
            # print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            print_str = ""
            if len(lave_recall_res_file) == 0:
            #     recall = json.loads(File.open(lave_recall_res_file, "r").read())
            #     print_str += format_recall_print(recall)+","
            # else:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            if clear_cache:
                File.clear_cache(overall_res_file)
                File.clear_cache(overall_refusal_res_file)
                File.clear_cache(lave_accuray_output)
                File.clear_cache(lave_recall_output)
            if overwrite or not File.isfile(overall_res_file) or not File.isfile(overall_refusal_res_file):
                from llava.eval.lave_metric import get_overall_lave_metrics, get_overall_lave_refusal_metrics
                if File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                    get_overall_lave_metrics(lave_accuray_output, lave_recall_output, overall_res_file)
                    get_overall_lave_refusal_metrics(lave_accuray_output, lave_recall_output, overall_refusal_res_file)
            else:
                overall_refusal = json.loads(File.open(overall_refusal_res_file, "r").read())
                if not calibration_only:
                    if full_results:
                        print_str += format_recall_print(overall_refusal)+","
                    else:
                        print_str += format_f1_print_refusal(overall_refusal)+","

                overall = json.loads(File.open(overall_res_file, "r").read())
                if not calibration_only:
                    if full_results:
                        print_str += format_acc_print(overall)+","
                    else:
                        print_str += format_acc_print_overall(overall)+","
        
            pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
            if not File.isfile(pred_prob_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"
            if clear_cache:
                File.clear_cache(pred_prob_file)
            
            if not File.isfile(pred_prob_file):
                if len(print_str):
                    print(print_str)
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            
            conf_weighted_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_pred_probyn_w_refusal01_lave_conf_weighted_result.json")
                
            conf_weighted_refusal_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_pred_probyn_w_refusal01_lave_conf_weighted_reward_refusal_result.json")

            if clear_cache:
                File.clear_cache(conf_weighted_results_file)
                File.clear_cache(conf_weighted_refusal_results_file)
            if overwrite or not File.isfile(conf_weighted_results_file) or not File.isfile(conf_weighted_refusal_results_file):
                # assert File.isfile(gt_prob_file), f"{gt_prob_file} does not exist"
                if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                    from llava.eval.lave_metric import get_pred_confidence_weighted_lave_metrics
                    get_pred_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, conf_weighted_results_file, refusal_reward=False, debug=False)

                    get_pred_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, conf_weighted_refusal_results_file, refusal_reward=True, debug=False)
            elif File.isfile(conf_weighted_results_file):
                conf_weighted = json.loads(File.open(conf_weighted_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = False: \n {format_acc_print(conf_weighted)}")

                if not calibration_only:
                    if full_results:
                        print_str += format_acc_print(conf_weighted)+","
                    else:
                        print_str += format_acc_print_overall(conf_weighted)+","

                conf_weighted = json.loads(File.open(conf_weighted_refusal_results_file, "r").read())
                # print(f"conf_weighted, refusal_reward = True: \n {format_acc_print(conf_weighted)}")

                if not calibration_only:
                    if full_results:
                        print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=True)+","


            from llava.eval.lave_metric import get_calibration_fig
            pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
            if not File.isfile(pred_prob_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"
            if clear_cache:
                File.clear_cache(pred_prob_file)
            calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_answerable_calibration_curve.png")
            calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_answerable_calibration_score.json")
            calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            if not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file) or overwrite:
                print(File.isfile(calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            elif File.isfile(calibration_score_file) and File.isfile(calibration_score_bins_file):
                calibration_score = json.loads(File.open(calibration_score_file, "r").read())
                if full_results:
                    print_str += format_calibration_print(calibration_score)+","
                calibration_score = json.loads(File.open(calibration_score_bins_file, "r").read())
                if full_results:
                    print_str += format_calibration_print(calibration_score)+","
            

            all_calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_all_calibration_curve.png")
            all_calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_all_calibration_score.json")
            all_calibration_score_bins_file = all_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            if not File.isfile(all_calibration_results_file) or not File.isfile(all_calibration_score_file) or not File.isfile(all_calibration_score_bins_file):
                print(File.isfile(all_calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            elif File.isfile(all_calibration_score_file) and File.isfile(all_calibration_score_bins_file):
                all_calibration_score = json.loads(File.open(all_calibration_score_file, "r").read())

                if full_results:
                    print_str += format_calibration_print(all_calibration_score)+","
                all_calibration_score = json.loads(File.open(all_calibration_score_bins_file, "r").read())

                if full_results:
                    print_str += format_calibration_print(all_calibration_score)+","
            

            unk_calibration_results_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_unk_calibration_curve.png")
            unk_calibration_score_file = os.path.join(dir_name, f"{eval_model_id.replace('/', '_')}_lave_w_refusal01_unk_calibration_score.json")
            unk_calibration_score_bins_file = unk_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
            if not File.isfile(unk_calibration_results_file) or not File.isfile(unk_calibration_score_file) or not File.isfile(unk_calibration_score_bins_file):
                print(File.isfile(unk_calibration_score_bins_file))
                get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
            elif File.isfile(unk_calibration_score_file) and File.isfile(unk_calibration_score_bins_file):
                unk_calibration_score = json.loads(File.open(unk_calibration_score_file, "r").read())
                if full_results:
                    print_str += format_calibration_print(unk_calibration_score)+","
                unk_calibration_score = json.loads(File.open(unk_calibration_score_bins_file, "r").read())
                if full_results:
                    print_str += format_calibration_print(unk_calibration_score)
                else:
                    print_str += format_calibration_print_overall(unk_calibration_score)
                
            if len(print_str):
                print(print_str)
            else:
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")


def gather_pope_results(dataset="pope", clear_cache=False):
    """Print one comma-separated row of POPE scores per model path.

    For every model path the four ``eval_result_*.json`` files under
    ``/{model_path}/{dataset}/answers/`` are read and the row
    ``adversarial f1, popular f1, random f1, overall average_f1`` is
    printed (each multiplied by 100, two decimals). A model path prints a
    ``>>>`` marker line instead when ``check_output_file_exists`` returns an
    empty value for its ``merge.jsonl`` or when any of the four result files
    is missing, so one incomplete model does not abort the whole run.

    Args:
        dataset: Name of the dataset sub-folder under each model path.
        clear_cache: When true, call ``File.clear_cache`` on every result
            file before checking for and reading it.
    """
    model_paths =[
    "models/Qwen/Qwen-VL-Chat",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_lrv_with_chart_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
    "dummy",

    "models/LLaVA/llava-v1.5-7b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_lrv_with_chart/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_llava_data",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k",
    "llava_next",
    "models/LLaVA/llava-v1.5-13b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "llava-v1.5-13b/sft-all",
    "models/LLaVA/llava-v1.6-vicuna-7b/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    "models/LLaVA/llava-v1.6-vicuna-13b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    "models/LLaVA/llava-v1.6-34b",
    "llava-v1.6-34b/sft-unk",
    "llava-v1.6-34b/sft-all",
    "llava-lora",
    "models/LLaVA/llava-v1.5-7b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k_new/",
    "models/LLaVA/llava-v1.5-13b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    ]
    # (result-file suffix, metric key), in the column order of the printed row.
    result_specs = [
        ("adversarial", "f1"),
        ("popular", "f1"),
        ("random", "f1"),
        ("overall", "average_f1"),
    ]
    print(f"===================={dataset}====================")
    for model_path in model_paths:
        output_file = f"/{model_path}/{dataset}/answers/merge.jsonl"
        output_file = check_output_file_exists(output_file, verbose=False)
        if len(output_file) == 0:
            print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            continue

        result_sources = [
            (f"/{model_path}/{dataset}/answers/eval_result_{split}.json", key)
            for split, key in result_specs
        ]
        if clear_cache:
            for result_file, _ in result_sources:
                File.clear_cache(result_file)
        # Same convention as gather_amber_results: a model without evaluation
        # output gets a marker row rather than crashing the whole gather.
        if not all(File.isfile(result_file) for result_file, _ in result_sources):
            print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            continue

        scores = []
        for result_file, key in result_sources:
            with File.open(result_file, "r") as f:
                result = json.load(f)
            scores.append(f"{result[key]*100:.2f}")
        print(",".join(scores))


def gather_amber_results(dataset="amber", clear_cache=False):
    """Print one ``F1,CHAIR`` row per model path from its AMBER results file.

    Each model path is resolved through ``check_output_file_exists`` on its
    ``answers/merge.jsonl``; when that returns an empty value, or when
    ``answers/results.json`` is not a file, a ``>>>`` marker line is printed
    for the model instead of a score row.

    Args:
        dataset: Name of the dataset sub-folder under each model path.
        clear_cache: When true, call ``File.clear_cache`` on the results
            file before checking for it.
    """
    model_paths =[
    "models/Qwen/Qwen-VL-Chat",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_lrv_with_chart_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
    "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
    "dummy",

    "models/LLaVA/llava-v1.5-7b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_lrv_with_chart/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_llava_data",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k",
    "llava_next",
    "models/LLaVA/llava-v1.5-13b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "llava-v1.5-13b/sft-all",
    "models/LLaVA/llava-v1.6-vicuna-7b/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    "models/LLaVA/llava-v1.6-vicuna-13b",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/sft_idk_debug_llava-v1.6-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    "models/LLaVA/llava-v1.6-34b",
    "llava-v1.6-34b/sft-unk",
    "llava-v1.6-34b/sft-all",
    "llava-lora",
    "models/LLaVA/llava-v1.5-7b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k_new/",
    "models/LLaVA/llava-v1.5-13b-lora",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based/",
    "<OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-13b-task-lora_unk_v1+gqa+ours_caption_based+llava_v1_5_mix665k/",
    ]
    print(f"===================={dataset}====================")
    for model_path in model_paths:
        answers_dir = f"/{model_path}/{dataset}/answers"
        merged_answers = check_output_file_exists(f"{answers_dir}/merge.jsonl", verbose=False)
        if len(merged_answers) == 0:
            # No merged answers for this model: emit the placeholder row.
            print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            continue

        results_path = f"{answers_dir}/results.json"
        if clear_cache:
            File.clear_cache(results_path)
        if not File.isfile(results_path):
            # Answers exist but were not scored yet: placeholder row as well.
            print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
            continue

        with File.open(results_path, "r") as results_handle:
            results = json.load(results_handle)
            score_row = f"{results['d']['F1']:.2f}," + f"{results['g']['CHAIR']:.2f}"
        # The row always holds two formatted numbers here, so it is never empty.
        print(score_row)

def plot_coverage_risk(
        coverage, risk, output_file
):
    """Plot ``risk`` against ``coverage`` and store the figure as a PNG.

    Returns immediately, without plotting, when ``output_file`` already
    exists. The image is rendered into an in-memory buffer and written
    through ``File.open``, so no scratch file is left in the working
    directory and concurrent runs cannot overwrite each other's image.

    Args:
        coverage: Values for the x axis.
        risk: Values for the y axis, paired with ``coverage``.
        output_file: Destination path of the PNG.
    """
    if File.isfile(output_file):
        return
    import io
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(coverage, risk, "s-")
        plt.xlabel("coverage")
        plt.ylabel("risk")
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # This function is called once per dataset/model pair; close the
        # figure so pyplot does not keep every one of them alive.
        plt.close(fig)
    with File.open(output_file, "wb") as f:
        f.write(buffer.getvalue())


def gather_threshold_results(eval_model_id="models/Mistral/Mistral-7B-Instruct-v0.2", overwrite=False, calibration_only=False):
    """Compute or print confidence-threshold baseline results per dataset and model.

    For every dataset/model pair and every confidence threshold, each group
    of result files is either generated through
    ``llava.eval.threshold_baseline_metric`` (when ``overwrite`` is set or a
    file of the group is missing) or, when already present, read back and
    appended to a comma-separated row. The row is printed per threshold; an
    empty row is replaced by a ``>>>`` marker line. Once every threshold of
    a pair has contributed a coverage/risk value, a coverage-vs-risk plot is
    written next to the results.

    Args:
        eval_model_id: Identifier of the evaluation model; ``/`` is replaced
            by ``_`` to build the result file names.
        overwrite: Regenerate result files even when they already exist. The
            "all" and "unk" calibration groups only react to missing files.
        calibration_only: Leave the recall/accuracy columns out of the
            printed rows; calibration columns are still printed.
    """
    def _read_json(path):
        # Read through a context manager so the handle is closed right away.
        with File.open(path, "r") as f:
            return json.loads(f.read())

    eval_tag = eval_model_id.replace('/', '_')
    model_paths = [
        "models/Qwen/Qwen-VL-Chat",
        "debug",
        "models/LLaVA/llava-v1.5-7b",
    ]
    datasets = [
        # "unk_vqa",
        # "unk_vqa_idk",
        # "unk_vqa_perturbed",
        "unk_vqa_validated",
        # "vizwiz_val",
        "vizwiz_val_prompt",
        "MMHal-Bench",
        "vqa_k_test_minus_llava_5k_prompt",
        "tdiuc_absurd_val_absurd_k_minus_llava_5k_prompt",
        "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        # "gqa_single_object_idk",
        "gqa_idk",
        # "docci_know_test.5k",
        # "docci_pred_test.5k",
        # "docci_ambiguity_test.5k",
        # "docci_complex_test.5k"
    ]
    # Datasets for which calibration figures/scores are produced.
    calibration_datasets = (
        "unk_vqa_validated",
        "vizwiz_val_prompt",
        "MMHal-Bench",
        "vqa_k_test_minus_llava_5k_prompt",
        "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        "gqa_single_object_idk",
        "gqa_idk",
        "docci_know_test.5k",
        "docci_pred_test.5k",
        "docci_ambiguity_test.5k",
        "docci_complex_test.5k",
    )
    # 0.0, 0.1, ..., 0.9 followed by 0.91, 0.92, ..., 1.0. The float value is
    # embedded in the result file names, so keep the `/ 100.` computation.
    thresholds = [t / 100. for t in list(range(0, 100, 10)) + list(range(91, 101))]

    for dataset in datasets:
        print(f"===================={dataset}====================")
        for model_path in model_paths:
            output_file = f"/{model_path}/{dataset}/merge.jsonl"
            output_file = check_output_file_exists(output_file, verbose=False)
            if len(output_file) == 0:
                # No threshold exists at this point; the previous message
                # formatted `pred_prob_threshold` before it was assigned.
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue
            # get lave result files
            dir_name = os.path.dirname(output_file)
            lave_accuray_output = check_debug_file_exists(f"{dir_name}/{eval_tag}_lave_output.jsonl")
            lave_recall_output = check_debug_file_exists(f"{dir_name}/{eval_tag}_refusal_lave_output.jsonl")
            # Not used below; the lookup is kept because check_debug_file_exists
            # is defined elsewhere and may have side effects — TODO confirm and drop.
            lave_acc_res_file = check_debug_file_exists(f"{dir_name}/{eval_tag}_lave_result.json")
            lave_recall_res_file = check_debug_file_exists(f"{dir_name}/{eval_tag}_refusal_lave_result.json")

            # Resolve the prediction-probability file once per dataset/model
            # pair so every branch below sees this pair's file (it used to be
            # assigned only inside some branches and could be unbound or stale).
            pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
            if not File.isfile(pred_prob_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"

            coverage_list = []
            risk_list = []
            for pred_prob_threshold in thresholds:
                print_str = ""
                overall_res_file = f"{dir_name}/{eval_tag}_thresh{pred_prob_threshold}_overall_result.json"
                overall_refusal_res_file = overall_res_file.replace("overall_result", "overall_refusal_result")
                if len(lave_recall_res_file) == 0:
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path} {pred_prob_threshold}")
                    continue

                # --- overall / overall-refusal metrics ---
                if overwrite or not File.isfile(overall_res_file) or not File.isfile(overall_refusal_res_file):
                    from llava.eval.threshold_baseline_metric import get_overall_threshold_metrics, get_overall_threshold_refusal_metrics
                    if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        get_overall_threshold_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, overall_res_file)
                        get_overall_threshold_refusal_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, overall_refusal_res_file)
                else:
                    overall_refusal = _read_json(overall_refusal_res_file)
                    if not calibration_only:
                        print_str += format_recall_print(overall_refusal)+","

                    overall = _read_json(overall_res_file)
                    if not calibration_only:
                        print_str += format_acc_print(overall)+","

                # --- confidence-weighted metrics that use the gt-probability file ---
                gt_conf_weighted_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_conf_weighted_result.json")
                gt_conf_weighted_refusal_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_conf_weighted_reward_refusal_result.json")
                if overwrite or not File.isfile(gt_conf_weighted_file) or not File.isfile(gt_conf_weighted_refusal_file):
                    gt_prob_file = f"/{model_path}/{dataset}/gt_prob_only_yes_or_no/merge.jsonl"
                    if not File.isfile(gt_prob_file):
                        gt_prob_file = f"/{model_path}/{dataset}/gt_probs_only_yes_or_no/merge.jsonl"
                    if File.isfile(gt_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        from llava.eval.threshold_baseline_metric import get_confidence_weighted_metrics
                        get_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, pred_prob_file, pred_prob_threshold, gt_conf_weighted_file, refusal_reward=False, debug=False)

                        get_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, pred_prob_file, pred_prob_threshold, gt_conf_weighted_refusal_file, refusal_reward=True, debug=False)
                else:
                    conf_weighted = _read_json(gt_conf_weighted_file)
                    if not calibration_only:
                        print_str += format_acc_print(conf_weighted)+","

                    conf_weighted = _read_json(gt_conf_weighted_refusal_file)
                    if not calibration_only:
                        print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=False)+","

                # --- confidence-weighted metrics that use only the pred-probability file ---
                pred_conf_weighted_file = os.path.join(dir_name, f"{eval_tag}_pred_probyn_thresh{pred_prob_threshold}_lave_conf_weighted_result.json")
                pred_conf_weighted_refusal_file = os.path.join(dir_name, f"{eval_tag}_pred_probyn_thresh{pred_prob_threshold}_lave_conf_weighted_reward_refusal_result.json")
                if overwrite or not File.isfile(pred_conf_weighted_file) or not File.isfile(pred_conf_weighted_refusal_file):
                    if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        from llava.eval.threshold_baseline_metric import get_pred_confidence_weighted_metrics
                        get_pred_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, pred_conf_weighted_file, refusal_reward=False, debug=False)

                        get_pred_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, pred_conf_weighted_refusal_file, refusal_reward=True, debug=False)
                else:
                    conf_weighted = _read_json(pred_conf_weighted_file)
                    if not calibration_only:
                        print_str += format_acc_print(conf_weighted)+","

                    conf_weighted = _read_json(pred_conf_weighted_refusal_file)
                    if not calibration_only:
                        print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=True)+","
                    # Coverage/risk come from the refusal-reward result.
                    coverage_list.append(conf_weighted['coverage'])
                    risk_list.append(conf_weighted['risk'])

                # --- calibration curves and scores ---
                if dataset in calibration_datasets:
                    from llava.eval.threshold_baseline_metric import get_calibration_fig
                    calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_answerable_calibration_curve.png")
                    calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_answerable_calibration_score.json")
                    calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    if not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file) or overwrite:
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, calibration_results_file)
                    else:
                        calibration_score = _read_json(calibration_score_file)
                        print_str += format_calibration_print(calibration_score)+","
                        calibration_score = _read_json(calibration_score_bins_file)
                        print_str += format_calibration_print(calibration_score)+","

                    all_calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_all_calibration_curve.png")
                    all_calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_all_calibration_score.json")
                    all_calibration_score_bins_file = all_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    if not File.isfile(all_calibration_results_file) or not File.isfile(all_calibration_score_file) or not File.isfile(all_calibration_score_bins_file):
                        print(File.isfile(all_calibration_score_bins_file))
                        # NOTE(review): the answerable-curve path is passed here,
                        # not the "all" path; presumably get_calibration_fig
                        # derives the other outputs from it — confirm.
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold,  calibration_results_file)
                    else:
                        all_calibration_score = _read_json(all_calibration_score_file)
                        print_str += format_calibration_print(all_calibration_score)+","
                        all_calibration_score = _read_json(all_calibration_score_bins_file)
                        print_str += format_calibration_print(all_calibration_score)+','

                    unk_calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_unk_calibration_curve.png")
                    unk_calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_unk_calibration_score.json")
                    unk_calibration_score_bins_file = unk_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    # The score file must be checked too (the bins file used to
                    # be checked twice), since the else-branch opens it.
                    if not File.isfile(unk_calibration_results_file) or not File.isfile(unk_calibration_score_file) or not File.isfile(unk_calibration_score_bins_file):
                        print(File.isfile(unk_calibration_score_bins_file))
                        # NOTE(review): answerable-curve path passed, as above — confirm.
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold,  calibration_results_file)
                    else:
                        unk_calibration_score = _read_json(unk_calibration_score_file)
                        print_str += format_calibration_print(unk_calibration_score)+","
                        unk_calibration_score = _read_json(unk_calibration_score_bins_file)
                        print_str += format_calibration_print(unk_calibration_score)
                if len(print_str):
                    print(print_str)
                else:
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path} {pred_prob_threshold}")

            coverage_risk_result_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh_coverage_vs_risk.png")
            # plot_coverage_risk does nothing once its output file exists, so
            # an empty or partial curve written now would never be replaced.
            # Only plot when every threshold contributed a point.
            if len(coverage_list) == len(thresholds):
                plot_coverage_risk(coverage_list, risk_list, coverage_risk_result_file)


def gather_threshold_results_w_refusal01(eval_model_id="models/Mistral/Mistral-7B-Instruct-v0.2", overwrite=False, calibration_only=False, full_results=False):
    """Print threshold-baseline results scored with the ``refusal01`` LAVE outputs.

    For every hard-coded (dataset, model) pair and every confidence threshold
    in 0.0, 0.1, ..., 0.9, 0.91, ..., 1.0 the function looks for cached result
    files next to the model's ``merge.jsonl``. A missing result is computed via
    ``llava.eval.threshold_baseline_metric`` (when the needed inputs exist);
    an existing one is loaded and appended to a comma-separated line that is
    printed once per threshold. After the sweep, the collected coverage/risk
    values are handed to ``plot_coverage_risk``.

    Args:
        eval_model_id: Evaluator model id; ``/`` is replaced by ``_`` to build
            the result file names.
        overwrite: Recompute the overall, confidence-weighted and answerable
            calibration results even if their files already exist.
        calibration_only: Leave the overall / confidence-weighted columns out
            of the printed line.
        full_results: Use the verbose formatters instead of the ``*_overall``
            / F1 summaries.
    """

    def _load_json(path):
        # Context manager so the handle is closed right after reading
        # (the bare File.open(...).read() pattern leaked it).
        with File.open(path, "r") as handle:
            return json.loads(handle.read())

    eval_tag = eval_model_id.replace('/', '_')
    model_paths = [
        "models/Qwen/Qwen-VL-Chat",
        "debug",
        "models/LLaVA/llava-v1.5-7b",
        "llava-v1.5-7b-lora",
        "models/LLaVA/llava-v1.5-7b-lora",
        "llava-v1.5-13b-lora",
        "models/LLaVA/llava-v1.5-13b-lora",
    ]
    datasets = [
        # "unk_vqa",
        # "unk_vqa_idk",
        # "unk_vqa_perturbed",
        "unk_vqa_validated",
        # "vizwiz_val",
        # "vizwiz_val_prompt",
        # "MMHal-Bench",
        # "vqa_k_test_minus_llava_5k_prompt",
        # "tdiuc_absurd_val_absurd_k_minus_llava_5k_prompt",
        # "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        # "gqa_single_object_idk",
        # "gqa_idk",
        "docci_know_test.5k",
        "docci_pred_test.5k",
        "docci_ambiguity_test.5k",
        "docci_complex_test.5k",
    ]
    # Datasets for which calibration figures/scores are produced and reported.
    calibration_datasets = (
        "unk_vqa_validated",
        "vizwiz_val_prompt",
        "MMHal-Bench",
        "vqa_k_test_minus_llava_5k_prompt",
        "unk_vqa_public_val_k_minus_llava_remove_ans_type2",
        "gqa_single_object_idk",
        "gqa_idk",
        "docci_know_test.5k",
        "docci_pred_test.5k",
        "docci_ambiguity_test.5k",
        "docci_complex_test.5k",
    )
    thresholds_percent = list(range(0, 100, 10)) + list(range(91, 101))

    for dataset in datasets:
        print(f"===================={dataset}====================")
        for model_path in model_paths:
            output_file = f"/{model_path}/{dataset}/merge.jsonl"
            output_file = check_output_file_exists(output_file, verbose=False)
            if len(output_file) == 0:
                # No threshold exists at this point yet; the original message
                # interpolated pred_prob_threshold here, which raised
                # UnboundLocalError on the first miss and printed a stale value
                # afterwards.
                print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                continue

            # get lave result files
            dir_name = os.path.dirname(output_file)
            lave_accuray_output = check_debug_file_exists(f"{dir_name}/{eval_tag}_lave_output.jsonl")
            lave_recall_output = check_debug_file_exists(f"{dir_name}/{eval_tag}_refusal01_lave_output.jsonl")
            # NOTE(review): lave_acc_res_file is never read below; the call is
            # kept in case check_debug_file_exists has side effects.
            lave_acc_res_file = check_debug_file_exists(f"{dir_name}/{eval_tag}_lave_result.json")
            lave_recall_res_file = check_debug_file_exists(f"{dir_name}/{eval_tag}_refusal01_lave_result.json")

            # The prediction-probability file does not depend on the threshold,
            # so resolve it (with its alternative folder spelling) once per model.
            pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
            if not File.isfile(pred_prob_file):
                pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"

            coverage_list = []
            risk_list = []
            for percent in thresholds_percent:
                print_str = ""
                pred_prob_threshold = percent / 100.
                overall_res_file = f"{dir_name}/{eval_tag}_thresh{pred_prob_threshold}_overall_result_w_refusal01.json"
                overall_refusal_res_file = overall_res_file.replace("overall_result", "overall_refusal_result")
                if len(lave_recall_res_file) == 0:
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path} {pred_prob_threshold}")
                    continue

                # ---- overall accuracy / refusal metrics ----
                if overwrite or not File.isfile(overall_res_file) or not File.isfile(overall_refusal_res_file):
                    from llava.eval.threshold_baseline_metric import get_overall_threshold_metrics, get_overall_threshold_refusal_metrics
                    if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        get_overall_threshold_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, overall_res_file)
                        get_overall_threshold_refusal_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, overall_refusal_res_file)
                else:
                    overall_refusal = _load_json(overall_refusal_res_file)
                    if not calibration_only:
                        if full_results:
                            print_str += format_recall_print(overall_refusal) + ","
                        else:
                            print_str += format_f1_print_refusal(overall_refusal) + ","

                    overall = _load_json(overall_res_file)
                    if not calibration_only:
                        if full_results:
                            print_str += format_acc_print(overall) + ","
                        else:
                            print_str += format_acc_print_overall(overall) + ","

                # ---- confidence-weighted metrics (without / with refusal reward) ----
                conf_weighted_results_file = os.path.join(dir_name, f"{eval_tag}_pred_probyn_thresh{pred_prob_threshold}_w_refusal01_lave_conf_weighted_result.json")
                conf_weighted_refusal_results_file = os.path.join(dir_name, f"{eval_tag}_pred_probyn_thresh{pred_prob_threshold}_w_refusal01_lave_conf_weighted_reward_refusal_result.json")

                if overwrite or not File.isfile(conf_weighted_results_file) or not File.isfile(conf_weighted_refusal_results_file):
                    if File.isfile(pred_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        from llava.eval.threshold_baseline_metric import get_pred_confidence_weighted_metrics
                        get_pred_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, conf_weighted_results_file, refusal_reward=False, debug=False)
                        get_pred_confidence_weighted_metrics(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, conf_weighted_refusal_results_file, refusal_reward=True, debug=False)
                else:
                    conf_weighted = _load_json(conf_weighted_results_file)
                    if not calibration_only:
                        if full_results:
                            print_str += format_acc_print(conf_weighted) + ","
                        else:
                            print_str += format_acc_print_overall(conf_weighted) + ","

                    conf_weighted = _load_json(conf_weighted_refusal_results_file)
                    if not calibration_only:
                        if full_results:
                            print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=True) + ","
                    # Coverage/risk come from the refusal-reward variant.
                    coverage_list.append(conf_weighted['coverage'])
                    risk_list.append(conf_weighted['risk'])

                # ---- calibration (answerable / all / unk) ----
                if dataset in calibration_datasets:
                    from llava.eval.threshold_baseline_metric import get_calibration_fig

                    calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_answerable_calibration_curve.png")
                    calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_answerable_calibration_score.json")
                    calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    if not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file) or overwrite:
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, calibration_results_file)
                    else:
                        calibration_score = _load_json(calibration_score_file)
                        if full_results:
                            print_str += format_calibration_print(calibration_score) + ","

                        calibration_score = _load_json(calibration_score_bins_file)
                        if full_results:
                            print_str += format_calibration_print(calibration_score) + ","

                    all_calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_all_calibration_curve.png")
                    all_calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_all_calibration_score.json")
                    all_calibration_score_bins_file = all_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    if not File.isfile(all_calibration_results_file) or not File.isfile(all_calibration_score_file) or not File.isfile(all_calibration_score_bins_file):
                        print(File.isfile(all_calibration_score_bins_file))
                        # NOTE(review): the answerable path is passed here, not the
                        # "all" path; presumably get_calibration_fig derives the
                        # sibling outputs from it -- confirm.
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, calibration_results_file)
                    else:
                        all_calibration_score = _load_json(all_calibration_score_file)
                        if full_results:
                            print_str += format_calibration_print(all_calibration_score) + ","

                        all_calibration_score = _load_json(all_calibration_score_bins_file)
                        if full_results:
                            print_str += format_calibration_print(all_calibration_score) + ','

                    unk_calibration_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_unk_calibration_curve.png")
                    unk_calibration_score_file = os.path.join(dir_name, f"{eval_tag}_probyn_thresh{pred_prob_threshold}_lave_w_refusal01_unk_calibration_score.json")
                    unk_calibration_score_bins_file = unk_calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                    # All three files are read/needed below, so all three are
                    # checked (the score file was previously never checked; the
                    # bins file was checked twice instead).
                    if not File.isfile(unk_calibration_results_file) or not File.isfile(unk_calibration_score_file) or not File.isfile(unk_calibration_score_bins_file):
                        print(File.isfile(unk_calibration_score_bins_file))
                        # NOTE(review): same as above -- the answerable path is passed.
                        get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, pred_prob_threshold, calibration_results_file)
                    else:
                        unk_calibration_score = _load_json(unk_calibration_score_file)
                        if full_results:
                            print_str += format_calibration_print(unk_calibration_score) + ","

                        unk_calibration_score = _load_json(unk_calibration_score_bins_file)
                        if full_results:
                            print_str += format_calibration_print(unk_calibration_score)
                        else:
                            print_str += format_calibration_print_overall(unk_calibration_score)

                if len(print_str):
                    print(print_str)
                else:
                    # Nothing cached for this threshold (results were just
                    # computed or inputs are missing).
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path} {pred_prob_threshold}")

            coverage_risk_result_file = os.path.join(dir_name, f"{eval_tag}_probyn_w_refusal01_thresh_coverage_vs_risk.png")
            plot_coverage_risk(coverage_list, risk_list, coverage_risk_result_file)



def gather_results_evaluator(overwrite=False):
    """Print LAVE results per evaluator model for the hard-coded models/datasets.

    For each (dataset, evaluator, model) triple the function looks for cached
    overall, confidence-weighted and calibration result files next to the
    model's ``merge.jsonl``. Missing results are computed through
    ``llava.eval.lave_metric`` when their inputs exist; existing ones are
    loaded and appended to a single line printed per model.

    Args:
        overwrite: Recompute every result even when its file already exists.
    """

    def _load_json(path):
        # Context manager so the handle is closed right after reading
        # (the bare File.open(...).read() pattern leaked it).
        with File.open(path, "r") as handle:
            return json.loads(handle.read())

    model_paths = [
        "models/Qwen/Qwen-VL-Chat",
        "models/LLaVA/llava-v1.5-7b",
    ]
    datasets = [
        "unk_vqa_validated_subset_100",
        # "MMHal-Bench",
    ]
    eval_model_ids = [
        "models/Mistral/Mistral-7B-Instruct-v0.2",
        "models/01-ai/Yi-34B-Chat-4bits",
        "gpt4",
    ]
    for dataset in datasets:
        print(f"===================={dataset}====================")
        for eval_model_id in eval_model_ids:
            eval_tag = eval_model_id.replace('/', '_')
            print(f"<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<{eval_model_id}")
            for model_path in model_paths:
                output_file = f"/{model_path}/{dataset}/merge.jsonl"
                output_file = check_output_file_exists(output_file, verbose=False)
                if len(output_file) == 0:
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                    continue

                # get lave result files
                dir_name = os.path.dirname(output_file)
                if "gpt4" in eval_model_id:
                    # gpt4 outputs live in per-metric sub-folders.
                    lave_accuray_output = f"{dir_name}/{eval_tag}_lave_acc/lave_output.jsonl"
                    lave_recall_output = f"{dir_name}/{eval_tag}_lave_refusal/refusal_lave_output.jsonl"
                else:
                    lave_accuray_output = f"{dir_name}/{eval_tag}_lave_output.jsonl"
                    lave_recall_output = f"{dir_name}/{eval_tag}_refusal_lave_output.jsonl"

                overall_res_file = f"{dir_name}/{eval_tag}_lave_overall_result.json"
                overall_refusal_res_file = overall_res_file.replace("overall_result", "overall_refusal_result")

                print_str = ""
                if not File.isfile(lave_recall_output):
                    print(f"not found: {lave_recall_output}")
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")
                    continue

                # ---- overall accuracy / refusal metrics ----
                if overwrite or not File.isfile(overall_res_file) or not File.isfile(overall_refusal_res_file):
                    from llava.eval.lave_metric import get_overall_lave_metrics, get_overall_lave_refusal_metrics
                    if File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        get_overall_lave_metrics(lave_accuray_output, lave_recall_output, overall_res_file)
                        get_overall_lave_refusal_metrics(lave_accuray_output, lave_recall_output, overall_refusal_res_file)
                else:
                    overall_refusal = _load_json(overall_refusal_res_file)
                    print_str += format_recall_print(overall_refusal) + ","

                    overall = _load_json(overall_res_file)
                    print_str += format_acc_print(overall) + ","

                # ---- confidence-weighted metrics (without / with refusal reward) ----
                conf_weighted_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_lave_conf_weighted_result.json")
                conf_weighted_refusal_results_file = os.path.join(dir_name, f"{eval_tag}_probyn_lave_conf_weighted_reward_refusal_result.json")
                if overwrite or not File.isfile(conf_weighted_results_file) or not File.isfile(conf_weighted_refusal_results_file):
                    gt_prob_file = f"/{model_path}/{dataset}/gt_prob_only_yes_or_no/merge.jsonl"
                    if not File.isfile(gt_prob_file):
                        # Fall back to the alternative folder spelling.
                        gt_prob_file = f"/{model_path}/{dataset}/gt_probs_only_yes_or_no/merge.jsonl"
                    if File.isfile(gt_prob_file) and File.isfile(lave_accuray_output) and File.isfile(lave_recall_output):
                        from llava.eval.lave_metric import get_confidence_weighted_lave_metrics
                        get_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, conf_weighted_results_file, refusal_reward=False, debug=False)
                        get_confidence_weighted_lave_metrics(lave_accuray_output, lave_recall_output, gt_prob_file, conf_weighted_refusal_results_file, refusal_reward=True, debug=False)
                else:
                    conf_weighted = _load_json(conf_weighted_results_file)
                    print_str += format_acc_print(conf_weighted) + ","

                    conf_weighted = _load_json(conf_weighted_refusal_results_file)
                    print_str += format_acc_print(conf_weighted, gt_not_yes_or_no=True, include_risk_coverage=True) + ","

                # ---- calibration on answerable questions ----
                from llava.eval.lave_metric import get_calibration_fig
                pred_prob_file = f"/{model_path}/{dataset}/pred_prob_only_yes_or_no/merge.jsonl"
                if not File.isfile(pred_prob_file):
                    pred_prob_file = f"/{model_path}/{dataset}/pred_probs_only_yes_or_no/merge.jsonl"
                calibration_results_file = os.path.join(dir_name, f"{eval_tag}_lave_answerable_calibration_curve.png")
                calibration_score_file = os.path.join(dir_name, f"{eval_tag}_lave_answerable_calibration_score.json")
                calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                if overwrite or not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file):
                    get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
                else:
                    calibration_score = _load_json(calibration_score_file)
                    # NOTE(review): no "," separator is appended here, unlike the
                    # other columns -- confirm whether that is intended.
                    print_str += format_calibration_print(calibration_score)

                # ---- calibration on all questions ----
                calibration_results_file = os.path.join(dir_name, f"{eval_tag}_lave_all_calibration_curve.png")
                calibration_score_file = os.path.join(dir_name, f"{eval_tag}_lave_all_calibration_score.json")
                calibration_score_bins_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
                if overwrite or not File.isfile(calibration_results_file) or not File.isfile(calibration_score_file) or not File.isfile(calibration_score_bins_file):
                    print(File.isfile(calibration_score_bins_file))
                    get_calibration_fig(lave_accuray_output, lave_recall_output, pred_prob_file, calibration_results_file)
                else:
                    calibration_score = _load_json(calibration_score_file)
                    print_str += format_calibration_print(calibration_score) + ","
                    calibration_score = _load_json(calibration_score_bins_file)
                    print_str += format_calibration_print(calibration_score)

                if len(print_str):
                    print(print_str)
                else:
                    # Nothing cached for this model (results were just computed
                    # or inputs are missing).
                    print(f">>>>>>>>>>>>>>>>>>>>>>{model_path}")



def compare_two_results():
    """Interactively compare answers of several models on the ``amber`` dataset.

    Loads each model's ``answers/merge.jsonl``, drops records whose prompt
    contains "Describe", prints the share of exact "yes" / "no" answers per
    model (relative to the number of yes/no answers), then walks through the
    first model's questions and prints every question whose answer text differs
    (case-insensitively) from another model's. After each differing question
    the function waits for the user to press Enter.
    """
    from tqdm import tqdm
    from collections import defaultdict
    models = [
        "models/Qwen/Qwen-VL-Chat",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
    ]
    dataset = "amber"
    yes_rate = defaultdict(int)
    no_rate = defaultdict(int)
    to_compre = []
    for model in models:
        output_file = f"/{model}/{dataset}/answers/merge.jsonl"
        # Close the handle once the lines are parsed instead of leaking it.
        with File.open(output_file, "r") as handle:
            answers = [json.loads(line) for line in handle.readlines()]
        qid2answers = {ans["question_id"]: ans for ans in answers if "Describe" not in ans["prompt"]}

        yes_count = 0
        no_count = 0
        for ans in qid2answers.values():
            text = ans["text"].lower()
            if text == "yes":
                yes_count += 1
            elif text == "no":
                no_count += 1
        total = yes_count + no_count
        # A model with no exact yes/no answers used to raise ZeroDivisionError;
        # report a rate of 0.0 for it instead.
        yes_rate[model] = yes_count / total if total else 0.0
        no_rate[model] = no_count / total if total else 0.0
        to_compre.append(qid2answers)

    print(f"yes_rate: {yes_rate}")
    print(f"no_rate: {no_rate}")

    print(f"finish loading....")

    for qid, ans in tqdm(to_compre[0].items()):
        continue_prompt = False
        base_text = ans["text"].lower()
        for idx in range(1, len(models)):
            if qid in to_compre[idx]:
                if base_text != to_compre[idx][qid]["text"].lower():
                    continue_prompt = True
                    print(f"{qid}: {ans['prompt']}\n\t{models[0]}\n\t\t{ans['text']}\n\t{models[idx]}\n\t\t{to_compre[idx][qid]['text']}")
            else:
                print(f"{qid} not in model {idx}")
        if continue_prompt:
            # Pause so the differences can be inspected one question at a time.
            input("continue?")





def compare_results_mmhal():
    """Interactively compare per-question LAVE scores of several models on MMHal-Bench.

    Loads, for every model, the raw answers plus the Mistral-evaluated
    ``refusal01`` and accuracy LAVE outputs (all keyed by ``question_id``).
    For each question of the first model, every other model whose
    ``answer_refusal`` or ``acc`` differs from the first model's is printed,
    followed by the first model's own answer and the ground truth; the function
    then waits for the user to press Enter.
    """
    from tqdm import tqdm
    models = [
        "models/Qwen/Qwen-VL-Chat",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_llava_data-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_unk+gqa_idk+docci_idk-_ep1/",
        "<OUTPUT_FOLDER>/qwen-vl/qwen-vl-chat-lora-finetune_idk+llava_data+gqa_idk+docci_train_ep1/",
    ]
    dataset = "MMHal-Bench"

    def _load_by_question_id(file_name):
        # One {question_id: record} dict per model, in the order of `models`.
        per_model = []
        for model in models:
            path = f"/{model}/{dataset}/answers/{file_name}"
            # Close the handle once the lines are parsed instead of leaking it.
            with File.open(path, "r") as handle:
                records = [json.loads(line) for line in handle.readlines()]
            per_model.append({rec["question_id"]: rec for rec in records})
        return per_model

    to_compre = _load_by_question_id("merge.jsonl")
    refusal_metrics = _load_by_question_id("models_Mistral_Mistral-7B-Instruct-v0.2_refusal01_lave_output.jsonl")
    accuracy_metrics = _load_by_question_id("models_Mistral_Mistral-7B-Instruct-v0.2_lave_output.jsonl")

    print(f"finish loading....")

    for qid, ans in tqdm(to_compre[0].items()):
        # The metric files are looked up with the string form of the id.
        metric_key = str(qid)
        base_refusal = refusal_metrics[0][metric_key]["answer_refusal"]
        base_acc = accuracy_metrics[0][metric_key]["acc"]
        continue_prompt = False
        gt = None  # set from the last compared model that has this question
        for idx in range(1, len(models)):
            if qid in to_compre[idx]:
                curr_refusal = refusal_metrics[idx][metric_key]["answer_refusal"]
                gt = refusal_metrics[idx][metric_key]["gt"]
                curr_acc = accuracy_metrics[idx][metric_key]["acc"]
                if curr_refusal != base_refusal or curr_acc != base_acc:
                    continue_prompt = True
                    print(f"\t{models[idx]}\n\t\t{to_compre[idx][qid]['text']}\n\t\tacc: {curr_acc}, refusal: {curr_refusal}")
            else:
                print(f"{qid} not in model {idx}")
        if continue_prompt:
            # The stray ")" that used to follow the base answer text is removed.
            print(f"{qid}: {ans['prompt']}\n\t{models[0]}\n\t\t{ans['text']}\n\t\tacc: {base_acc}, refusal: {base_refusal}\n\tGT:{gt}")
            input("continue?")

def plot_accuracy_vs_calibration():
    """Draw two scatter-plus-linear-fit figures from a hard-coded results table.

    ``./ours_vs_lave.png`` plots the 'Ours Accuracy' column against
    'Lave Accuracy'; ``./accuracy_vs_calibration.png`` plots both accuracy
    columns against 'Expected Calibration Score'. Both files are written to
    the current working directory.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    # Hard-coded results table (one entry per evaluated configuration).
    table = pd.DataFrame({
        'Lave Accuracy': [
            60.26, 69.97, 53.71, 70.1, 57.53, 75.55, 73.52, 75.88, 44.32, 72.46, 75.58, 61.41,
            61.63, 67.49, 65.06, 72.05, 70.25, 72.18, 66.9, 41.33, 70.25, 72.81, 69.49, 69.9,
            70.95, 66.98, 66.76, 60.24, 66.92, 66.45, 56.98, 60.15, 62.06, 60.14, 66.97, 66.55,
            61.46, 38.18, 63.72, 47.41, 73.76, 38.68, 73.45, 73.73, 72.57, 35.7, 74.45, 73.6,
            3.33, 49.4, 68.32, 71.68, 68.27, 70.01, 68.47, 71.87, 34.05, 63.34, 67.85, 71.14,
            67.89, 68.9, 51.35, 63.47, 38.68, 53.51, 59.42, 59.33, 63.45, 64.82, 38.27, 60.41,
            60.53, 65.21, 38.97, 74.32, 50.18, 75.14, 53.55, 75.64, 75.6, 58.37, 53.79,
        ],
        'Ours Accuracy': [
            7.85, 22.17, 3.82, 20.64, 6.4, 35.11, 22.23, 32.05, -7.82, 30.35, 36.88, 9.15,
            9.2, 17.99, 15.31, 22.09, 19.96, 21.31, 15.59, -10.29, 22.65, 28.84, 17.23, 24.62,
            19.28, 15.05, 16.86, 7.9, 15.17, 17.22, 9.75, 12.32, 14.41, 7.62, 17.54, 17.43,
            13.94, -9.43, 21.3, -5.17, 23.02, -9.31, 21.65, 17.52, 28.68, -21.69, 30.26, 18.8,
            -27.27, -0.57, 18.14, 22.9, 16.18, 19.34, 19.92, 18, -21.01, 22.86, 23.93, 28.4,
            22.63, 25.35, 6.34, 17.7, -6.28, 9.81, 15.4, 13.96, 15.89, 17.63, -9.16, 14.88,
            16.61, 17.58, -5.56, 26.51, 4.34, 22.52, 11.76, 36.19, 29.71, 19.68, 11.86,
        ],
        'Expected Calibration Score': [
            0.5463, 0.4746, 0.5347, 0.4492, 0.5346, 0.3403, 0.4849, 0.3983, 0.5501, 0.3528,
            0.3326, 0.5212, 0.551, 0.4758, 0.4267, 0.4781, 0.4703, 0.4902, 0.4957, 0.5509,
            0.4518, 0.3844, 0.4397, 0.3776, 0.4765, 0.5433, 0.5206, 0.5438, 0.5294, 0.5077,
            0.4932, 0.5016, 0.5025, 0.5475, 0.5176, 0.5082, 0.5038, 0.4282, 0.2798, 0.5549,
            0.4602, 0.4469, 0.472, 0.5151, 0.3764, 0.6168, 0.398, 0.5126, 0.2575, 0.1554,
            0.4301, 0.4157, 0.3255, 0.4611, 0.3885, 0.5018, 0.5818, 0.2914, 0.2659, 0.3098,
            0.3159, 0.361, 0.3943, 0.4727, 0.3658, 0.3073, 0.3649, 0.4672, 0.504, 0.505,
            0.4245, 0.4373, 0.3736, 0.5084, 0.3995, 0.4545, 0.403, 0.4763, 0.3173, 0.3609,
            0.4311, 0.3087, 0.3481,
        ],
    })

    lave_acc = table['Lave Accuracy']
    calibration = table['Expected Calibration Score']
    ours_acc = table['Ours Accuracy']

    def _scatter_with_fit(xs, ys, xlabel, ylabel, color, point_label, line_label, legend_cols):
        # Scatter the raw points, then overlay a degree-1 least-squares fit
        # evaluated at the sorted unique x values.
        plt.scatter(xs, ys, alpha=0.7, color=color, label=point_label)
        fitted = np.poly1d(np.polyfit(xs, ys, 1))
        x_grid = np.unique(xs)
        plt.plot(x_grid, fitted(x_grid), color=color, lw=2, label=line_label)
        plt.xlabel(xlabel, fontsize=14)
        plt.ylabel(ylabel, fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=legend_cols, fontsize=12)

    # Figure 1: conf-weighted accuracy against LAVE accuracy.
    plt.figure(figsize=(8, 6))
    _scatter_with_fit(
        lave_acc, ours_acc,
        r'$Lave_{idk}$ Acc.', 'Conf-weighted Acc.', '#699fc9',
        'Actual data point', 'Fitted line', 1,
    )
    plt.tight_layout()
    plt.savefig("./ours_vs_lave.png")

    # Figure 2: both accuracy measures against the calibration score, one
    # colour per series, sharing the same axes.
    plt.figure(figsize=(8, 6))
    series = (
        (ours_acc, '#699fc9', 'Conf-weighted Acc.'),
        (lave_acc, '#f96653', r'$Lave_{idk}$ Acc.'),
    )
    for values, color, label in series:
        _scatter_with_fit(
            calibration, values,
            'Expected Calibration Error', 'Acc.', color,
            f'{label} actual data point', f'{label} fitted line', 2,
        )
    plt.tight_layout()
    plt.savefig("./accuracy_vs_calibration.png")


if __name__ == "__main__":
    # Script entry point: hand control to python-fire, which builds the
    # command-line interface. Imported here so that merely importing this
    # module does not require `fire`.
    from fire import Fire
    Fire()